diff --git a/WORKSPACE b/WORKSPACE index eb3efd275..1bd865c1e 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -361,6 +361,19 @@ http_archive( ], ) +# PyTorch archives + +# 2020-03-12 +# On https://pytorch.org: Stable > Linux > LibTorch > C++ > (CUDA) None > cxx11 ABI +http_archive( + name = "linux_libtorch_cpu", + build_file = "@//third_party:libtorch_linux.BUILD", + sha256 = "33a9dd142d0497375db42b055bd90780f9d92047a19edc8891e6232e2b5bdba7", + strip_prefix = "libtorch", + type = "zip", + url = "https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-1.4.0%2Bcpu.zip", +) + #Tensorflow repo should always go after the other external dependencies. # 2020-08-30 _TENSORFLOW_GIT_COMMIT = "57b009e31e59bd1a7ae85ef8c0232ed86c9b71db" diff --git a/docs/solutions/object_classification.md b/docs/solutions/object_classification.md new file mode 100644 index 000000000..e5c69a29d --- /dev/null +++ b/docs/solutions/object_classification.md @@ -0,0 +1,79 @@ +--- +layout: default +title: Object Classification +parent: Solutions +nav_order: TODO +--- + +# MediaPipe Object Classification +{: .no_toc } + +1. TOC +{:toc} +--- + +## Example Apps + +Note: To visualize a graph, copy the graph and paste it into +[MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how +to visualize its associated subgraphs, please see +[visualizer documentation](../tools/visualizer.md). + + + +### Desktop + +#### Live Camera Input + +Please first see general instructions for +[desktop](../getting_started/building_examples.md#desktop) on how to build MediaPipe examples. + +* Graph: + [`mediapipe/graphs/object_classification/object_classification_desktop_live.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_classification/object_classification_desktop_live.pbtxt) +* Target: + [`mediapipe/examples/desktop/object_classification:object_classification_pytorch_cpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/object_classification/BUILD) + +#### Video File Input + +* With a PyTorch Model + + This uses a MobileNetv2 trace model from PyTorch Hub. To fetch and prepare it, run: + + ```bash + python mediapipe/models/trace_mobilenetv2.py + ``` + + The pipeline is implemented in this + [graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_classification/object_classification_desktop_live.pbtxt). + + To build the application, run: + + ```bash + bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 mediapipe/examples/desktop/object_classification:object_classification_pytorch_cpu + ``` + + To run the application, replace `` and `` in the command below with your own paths: + + Tip: You can find a test video available in + `mediapipe/examples/desktop/object_detection`. + + ``` + GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/object_classification/object_classification_pytorch_cpu \ + --calculator_graph_config_file=mediapipe/graphs/object_classification/object_classification_desktop_live.pbtxt \ + --input_side_packets=input_video_path=,output_video_path= + ``` + \ No newline at end of file diff --git a/mediapipe/calculators/pytorch/BUILD b/mediapipe/calculators/pytorch/BUILD new file mode 100644 index 000000000..8880fca5d --- /dev/null +++ b/mediapipe/calculators/pytorch/BUILD @@ -0,0 +1,138 @@ +# Copyright 2020 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_proto//proto:defs.bzl", "proto_library") +load("@rules_cc//cc:defs.bzl", "cc_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:private"]) + +proto_library( + name = "pytorch_converter_calculator_proto", + srcs = ["pytorch_converter_calculator.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +proto_library( + name = "pytorch_inference_calculator_proto", + srcs = ["pytorch_inference_calculator.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +proto_library( + name = "pytorch_tensors_to_classification_calculator_proto", + srcs = ["pytorch_tensors_to_classification_calculator.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "pytorch_converter_calculator_cc_proto", + srcs = ["pytorch_converter_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//visibility:public"], + deps = [":pytorch_converter_calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "pytorch_inference_calculator_cc_proto", + srcs = ["pytorch_inference_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//visibility:public"], + deps = [":pytorch_inference_calculator_proto"], +) + +mediapipe_cc_proto_library( + name = "pytorch_tensors_to_classification_calculator_cc_proto", + srcs = ["pytorch_tensors_to_classification_calculator.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//visibility:public"], + deps = [":pytorch_tensors_to_classification_calculator_proto"], +) + +cc_library( + name = "pytorch_converter_calculator", + srcs = ["pytorch_converter_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":pytorch_converter_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler", + "//third_party:libtorch", + ] + select({ + "//mediapipe:ios": [ + "//mediapipe/gpu:gpu_buffer", + "//mediapipe/objc:util", + ], + "//conditions:default": [], + }), + alwayslink = 1, +) + +cc_library( + name = "pytorch_inference_calculator", + srcs = ["pytorch_inference_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":pytorch_inference_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler", + "//mediapipe/util:resource_util", + "//third_party:libtorch", + ], + alwayslink = 1, +) + +cc_library( + name = "pytorch_tensors_to_classification_calculator", + srcs = ["pytorch_tensors_to_classification_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":pytorch_tensors_to_classification_calculator_cc_proto", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "//mediapipe/framework/port:status", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:location", + "//mediapipe/framework/port:ret_check", + "//mediapipe/util:resource_util", + "//third_party:libtorch", + ] + select({ + "//mediapipe:android": [ + "//mediapipe/util/android/file/base", + ], + "//mediapipe:apple": [ + "//mediapipe/util/android/file/base", + ], + "//mediapipe:macos": [ + "//mediapipe/framework/port:file_helpers", + ], + "//conditions:default": [ + "//mediapipe/framework/port:file_helpers", + ], + }), + alwayslink = 1, +) diff --git a/mediapipe/calculators/pytorch/pytorch_converter_calculator.cc b/mediapipe/calculators/pytorch/pytorch_converter_calculator.cc new file mode 100644 index 000000000..b2696f73f --- /dev/null +++ b/mediapipe/calculators/pytorch/pytorch_converter_calculator.cc @@ -0,0 +1,179 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "mediapipe/calculators/pytorch/pytorch_converter_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "torch/script.h" +#include "torch/torch.h" + +#if defined(MEDIAPIPE_IOS) +#include "mediapipe/gpu/gpu_buffer.h" +#include "mediapipe/objc/util.h" +#endif // iOS + +namespace mediapipe { + +namespace { +constexpr char kImageTag[] = "IMAGE"; +constexpr char kImageGpuTag[] = "IMAGE_GPU"; +constexpr char kTensorsTag[] = "TENSORS"; + +using Outputs = std::vector; +} // namespace + +// Calculator for normalizing and converting an ImageFrame +// into a PyTorchTensor. +// +// This calculator is designed to be used with the PyTorchInferenceCalculator, +// as a pre-processing step for calculator inputs. +// +// IMAGE inputs are normalized to [-1,1] (default) or [0,1], +// specified by options (unless outputting a quantized tensor). +// +// Input: +// One of the following tags: +// IMAGE - ImageFrame (assumed to be 8-bit or 32-bit data). +// +// Output: +// One of the following tags: +// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32, or kTfLiteUint8. +// +// Example use: +// node { +// calculator: "PyTorchConverterCalculator" +// input_stream: "IMAGE:input_image" +// output_stream: "TENSORS:image_tensor" +// options: { +// [mediapipe.PyTorchConverterCalculatorOptions.ext] { +// zero_center: true +// } +// } +// } +// +// IMPORTANT Notes: +// This calculator uses FixedSizeInputStreamHandler by default. +// +class PyTorchConverterCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + ::mediapipe::Status Close(CalculatorContext* cc) override; + + private: + ::mediapipe::PyTorchConverterCalculatorOptions options_; + bool has_image_tag_; + bool has_image_gpu_tag_; +}; +REGISTER_CALCULATOR(PyTorchConverterCalculator); + +::mediapipe::Status PyTorchConverterCalculator::GetContract( + CalculatorContract* cc) { + const bool has_image_tag = cc->Inputs().HasTag(kImageTag); + const bool has_image_gpu_tag = cc->Inputs().HasTag(kImageGpuTag); + // Confirm only one of the input streams is present. + RET_CHECK(has_image_tag ^ has_image_gpu_tag); + + const bool has_tensors_tag = cc->Outputs().HasTag(kTensorsTag); + RET_CHECK(has_tensors_tag); + + if (has_image_tag) cc->Inputs().Tag(kImageTag).Set(); + if (has_image_gpu_tag) { +#if defined(MEDIAPIPE_IOS) + cc->Inputs().Tag(kImageGpuTag).Set(); +#else + RET_CHECK_FAIL() << "GPU processing not enabled."; +#endif + } + + if (has_tensors_tag) cc->Outputs().Tag(kTensorsTag).Set(); + + // Assign this calculator's default InputStreamHandler. + cc->SetInputStreamHandler("FixedSizeInputStreamHandler"); + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status PyTorchConverterCalculator::Open(CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + + options_ = cc->Options<::mediapipe::PyTorchConverterCalculatorOptions>(); + + has_image_tag_ = cc->Inputs().HasTag(kImageTag); + has_image_gpu_tag_ = cc->Inputs().HasTag(kImageGpuTag); + + if (has_image_gpu_tag_) { +#if !defined(MEDIAPIPE_IOS) + RET_CHECK_FAIL() << "GPU processing not enabled."; +#endif + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status PyTorchConverterCalculator::Process(CalculatorContext* cc) { + cv::Mat image; + if (has_image_gpu_tag_) { +#if defined(MEDIAPIPE_IOS) && MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + const auto& input = cc->Inputs().Tag(kImageGpuTag).Get(); + std::unique_ptr frame = + CreateImageFrameForCVPixelBuffer(input.GetCVPixelBufferRef()); + image = mediapipe::formats::MatView(frame.release()); +#else + RET_CHECK_FAIL() << "GPU processing is not enabled."; +#endif + } + if (has_image_tag_) { + auto& output_frame = + cc->Inputs().Tag(kImageTag).Get(); + image = mediapipe::formats::MatView(&output_frame); + } + cv::cvtColor(image, image, cv::COLOR_BGR2RGB); + const auto width = image.cols, height = image.rows, + num_channels = image.channels(); + RET_CHECK_EQ(num_channels, 3) << "Only RGB images are supported for now"; + + cv::Mat img_float; + image.convertTo(img_float, CV_32F, 1.0 / 255); + auto img_tensor = torch::from_blob(img_float.data, {1, width, height, 3}); + img_tensor = img_tensor.permute({0, 3, 1, 2}); // To NCWH + if (options_.per_channel_normalizations().size() > 0) + for (int i = 0; i < num_channels; ++i) { + const auto& subdiv = options_.per_channel_normalizations(i); + img_tensor[0][i] = img_tensor[0][i].sub_(subdiv.sub()).div_(subdiv.div()); + } + + auto output_tensors = absl::make_unique(); + output_tensors->reserve(1); + output_tensors->emplace_back(img_tensor.cpu()); + cc->Outputs() + .Tag(kTensorsTag) + .Add(output_tensors.release(), cc->InputTimestamp()); + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status PyTorchConverterCalculator::Close(CalculatorContext* cc) { + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/pytorch/pytorch_converter_calculator.proto b/mediapipe/calculators/pytorch/pytorch_converter_calculator.proto new file mode 100644 index 000000000..22b3ac9c3 --- /dev/null +++ b/mediapipe/calculators/pytorch/pytorch_converter_calculator.proto @@ -0,0 +1,34 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message PyTorchConverterCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional PyTorchConverterCalculatorOptions ext = 245877797; + } + + message SubDiv { + optional float sub = 1 [default = 0]; + optional float div = 2 [default = 1]; + } + // Normalizations to apply per input image channel. + // There must be exactly as many items in this list as there are channels + // or none. + repeated SubDiv per_channel_normalizations = 1; +} diff --git a/mediapipe/calculators/pytorch/pytorch_inference_calculator.cc b/mediapipe/calculators/pytorch/pytorch_inference_calculator.cc new file mode 100644 index 000000000..7573747a9 --- /dev/null +++ b/mediapipe/calculators/pytorch/pytorch_inference_calculator.cc @@ -0,0 +1,172 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "mediapipe/calculators/pytorch/pytorch_inference_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/util/resource_util.h" +#include "torch/script.h" +#include "torch/torch.h" + +namespace mediapipe { + +namespace { +constexpr char kTensorsTag[] = "TENSORS"; + +using Inputs = std::vector; +using Outputs = torch::Tensor; +} // namespace + +// Calculator Header Section + +// Runs inference on the provided input TFLite tensors and TFLite model. +// +// Creates an interpreter with given model and calls invoke(). +// Optionally run inference on CPU/GPU. +// +// This calculator is designed to be used with the TfLiteConverterCalcualtor, +// to get the appropriate inputs. +// +// When the input tensors are on CPU, gpu inference is optional and can be +// specified in the calculator options. +// +// Input: +// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32 or kTfLiteUInt8 +// +// Output: +// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32 or kTfLiteUInt8 +// +// Example use: +// node { +// calculator: "PyTorchInferenceCalculator" +// input_stream: "TENSORS:tensor_image" +// output_stream: "TENSORS:tensors" +// options: { +// [mediapipe.PyTorchInferenceCalculatorOptions.ext] { +// model_path: "modelname.tflite" +// delegate { gpu {} } +// } +// } +// } +// +// IMPORTANT Notes: +// This calculator uses FixedSizeInputStreamHandler by default. +// +class PyTorchInferenceCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + ::mediapipe::Status Close(CalculatorContext* cc) override; + + private: + ::mediapipe::PyTorchInferenceCalculatorOptions options_; + torch::jit::script::Module module_; + torch::jit::IValue hidden_state_; +}; +REGISTER_CALCULATOR(PyTorchInferenceCalculator); + +// Calculator Core Section + +::mediapipe::Status PyTorchInferenceCalculator::GetContract( + CalculatorContract* cc) { + RET_CHECK(cc->Inputs().HasTag(kTensorsTag)); + RET_CHECK(cc->Outputs().HasTag(kTensorsTag)); + + if (cc->Inputs().HasTag(kTensorsTag)) + cc->Inputs().Tag(kTensorsTag).Set(); + + if (cc->Outputs().HasTag(kTensorsTag)) + cc->Outputs().Tag(kTensorsTag).Set(); + + // Assign this calculator's default InputStreamHandler. + cc->SetInputStreamHandler("FixedSizeInputStreamHandler"); + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status PyTorchInferenceCalculator::Open(CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + + options_ = cc->Options<::mediapipe::PyTorchInferenceCalculatorOptions>(); + + std::string model_path = options_.model_path(); + ASSIGN_OR_RETURN(model_path, mediapipe::PathToResourceAsFile(model_path)); + try { + // https://github.com/pytorch/ios-demo-app/issues/8#issuecomment-612996683 + auto qengines = at::globalContext().supportedQEngines(); + if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) != + qengines.end()) { + LOG(INFO) << "Using QEngine at::QEngine::QNNPACK"; + at::globalContext().setQEngine(at::QEngine::QNNPACK); + } +#if defined(MEDIAPIPE_IOS) + else { + RET_CHECK_FAIL() << "QEngine::QNNPACK is required for iOS"; + } +#endif + + module_ = torch::jit::load(model_path); + module_.eval(); + } catch (const std::exception& e) { + LOG(ERROR) << e.what(); + return ::mediapipe::UnknownError(e.what()); + } + + if (options_.model_has_hidden_state()) + hidden_state_ = torch::zeros({1, 1, 10}); // TODO: read options + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status PyTorchInferenceCalculator::Process(CalculatorContext* cc) { + const auto inputs = cc->Inputs().Tag(kTensorsTag).Get(); + RET_CHECK_GT(inputs.size(), 0); + + // Disables autograd + torch::autograd::AutoGradMode guard(false); + // Disables autograd even more? https://github.com/pytorch/pytorch/pull/26477 + at::AutoNonVariableTypeMode non_var_type_mode(true); + + Outputs out_tensor; + if (options_.model_has_hidden_state()) { + RET_CHECK_EQ(inputs.size(), 1) << "Not sure how to forward() hidden state"; + // auto tuple = torch::ivalue::Tuple::create({inp,hidden_state_}); + const auto result = module_.forward({{inputs[0], hidden_state_}}); + const auto out = result.toTuple()->elements(); + out_tensor = out[0].toTensor(); + hidden_state_ = out[1].toTensor(); + } else { + const auto result = module_.forward(std::move(inputs)); + out_tensor = result.toTensor(); + } + + auto out = absl::make_unique(out_tensor); + cc->Outputs().Tag(kTensorsTag).Add(out.release(), cc->InputTimestamp()); + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status PyTorchInferenceCalculator::Close(CalculatorContext* cc) { + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/pytorch/pytorch_inference_calculator.proto b/mediapipe/calculators/pytorch/pytorch_inference_calculator.proto new file mode 100644 index 000000000..8a93d2608 --- /dev/null +++ b/mediapipe/calculators/pytorch/pytorch_inference_calculator.proto @@ -0,0 +1,71 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message PyTorchInferenceCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional PyTorchInferenceCalculatorOptions ext = 233877213; + } + + // message Delegate { + // // Default inference provided by tflite. + // message TfLite {} + // // Delegate to run GPU inference depending on the device. + // // (Can use OpenGl, OpenCl, Metal depending on the device.) + // message Gpu {} + // // Android only. + // message Nnapi {} + + // oneof delegate { + // TfLite tflite = 1; + // Gpu gpu = 2; + // Nnapi nnapi = 3; + // } + // } + + // Path to the PyTorch trace model (ex: /path/to/modelname.pt). + optional string model_path = 1; + + // Whether model has hidden state. + optional bool model_has_hidden_state = 2; + + // // Whether the TF Lite GPU or CPU backend should be used. Effective only + // when + // // input tensors are on CPU. For input tensors on GPU, GPU backend is + // always + // // used. + // // DEPRECATED: configure "delegate" instead. + // optional bool use_gpu = 2 [deprecated = true, default = false]; + + // // Android only. When true, an NNAPI delegate will be used for inference. + // // If NNAPI is not available, then the default CPU delegate will be used + // // automatically. + // // DEPRECATED: configure "delegate" instead. + // optional bool use_nnapi = 3 [deprecated = true, default = false]; + + // // The number of threads available to the interpreter. Effective only when + // // input tensors are on CPU and 'use_gpu' is false. + // optional int32 cpu_num_thread = 4 [default = -1]; + + // // TfLite delegate to run inference. + // // NOTE: calculator is free to choose delegate if not specified explicitly. + // // NOTE: use_gpu/use_nnapi are ignored if specified. (Delegate takes + // // precedence over use_* deprecated options.) + // optional Delegate delegate = 5; +} diff --git a/mediapipe/calculators/pytorch/pytorch_tensors_to_classification_calculator.cc b/mediapipe/calculators/pytorch/pytorch_tensors_to_classification_calculator.cc new file mode 100644 index 000000000..b04b9a881 --- /dev/null +++ b/mediapipe/calculators/pytorch/pytorch_tensors_to_classification_calculator.cc @@ -0,0 +1,173 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "mediapipe/calculators/pytorch/pytorch_tensors_to_classification_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/util/resource_util.h" +#include "torch/torch.h" +#if defined(MEDIAPIPE_MOBILE) +#include "mediapipe/util/android/file/base/file.h" +#include "mediapipe/util/android/file/base/helpers.h" +#else +#include "mediapipe/framework/port/file_helpers.h" +#endif + +namespace mediapipe { + +namespace { +constexpr char kTensorsTag[] = "TENSORS"; +constexpr char kClassificationListTag[] = "CLASSIFICATION_LIST"; + +using Inputs = torch::Tensor; +} // namespace + +// Convert result PyTorch tensors from classification models into MediaPipe +// classifications. +// +// Input: +// TENSORS - Vector of PyTorch of type kTfLiteFloat32 containing one +// tensor, the size of which must be (1, * num_classes). +// Output: +// CLASSIFICATIONS - Result MediaPipe ClassificationList. The score and index +// fields of each classification are set, while the label +// field is only set if label_map_path is provided. +// +// Usage example: +// node { +// calculator: "PyTorchTensorsToClassificationCalculator" +// input_stream: "TENSORS:tensors" +// output_stream: "CLASSIFICATIONS:classifications" +// options: { +// [mediapipe.PyTorchTensorsToClassificationCalculatorOptions.ext] { +// num_classes: 1024 +// min_score_threshold: 0.1 +// label_map_path: "labelmap.txt" +// } +// } +// } +class PyTorchTensorsToClassificationCalculator : public CalculatorBase { + public: + static ::mediapipe::Status GetContract(CalculatorContract* cc); + + ::mediapipe::Status Open(CalculatorContext* cc) override; + ::mediapipe::Status Process(CalculatorContext* cc) override; + ::mediapipe::Status Close(CalculatorContext* cc) override; + + private: + ::mediapipe::PyTorchTensorsToClassificationCalculatorOptions options_; + std::unordered_map label_map_; + bool label_map_loaded_ = false; +}; +REGISTER_CALCULATOR(PyTorchTensorsToClassificationCalculator); + +::mediapipe::Status PyTorchTensorsToClassificationCalculator::GetContract( + CalculatorContract* cc) { + RET_CHECK(cc->Inputs().HasTag(kTensorsTag)); + cc->Inputs().Tag(kTensorsTag).Set(); + + RET_CHECK(!cc->Outputs().GetTags().empty()); + if (cc->Outputs().HasTag(kClassificationListTag)) { + cc->Outputs().Tag(kClassificationListTag).Set(); + } + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status PyTorchTensorsToClassificationCalculator::Open( + CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + + options_ = cc->Options< + ::mediapipe::PyTorchTensorsToClassificationCalculatorOptions>(); + + if (options_.has_label_map_path()) { + std::string string_path; + ASSIGN_OR_RETURN(string_path, + PathToResourceAsFile(options_.label_map_path())); + std::string label_map_string; + MP_RETURN_IF_ERROR(file::GetContents(string_path, &label_map_string)); + + std::istringstream stream(label_map_string); + std::string line; + int i = 0; + while (std::getline(stream, line)) label_map_[i++] = line; + label_map_loaded_ = true; + } + + if (options_.has_top_k()) RET_CHECK_GT(options_.top_k(), 0); + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status PyTorchTensorsToClassificationCalculator::Process( + CalculatorContext* cc) { + const auto& input_tensors = cc->Inputs().Tag(kTensorsTag).Get(); + RET_CHECK_EQ(input_tensors.dim(), 2); + + std::tuple result = + input_tensors.sort(/*dim*/ -1, /*descending*/ true); + const torch::Tensor scores_tensor = std::get<0>(result)[0]; + const torch::Tensor indices_tensor = + std::get<1>(result)[0].toType(torch::kInt32); + + auto scores = scores_tensor.accessor(); + auto indices = indices_tensor.accessor(); + + const auto indices_count = indices.size(0); + RET_CHECK_EQ(indices_count, scores.size(0)); + if (label_map_loaded_) + RET_CHECK_EQ(indices_count, label_map_.size()) + << "need: " << indices_count << ", got: " << label_map_.size(); + + // RET_CHECK_GE(indices_count, options_.top_k()); + auto top_k = indices.size(0); + if (options_.has_top_k()) top_k = options_.top_k(); + + auto classification_list = absl::make_unique(); + for (int i = 0; i < indices_count; ++i) { + if (classification_list->classification_size() == top_k) break; + const float score = scores[i]; + const int index = indices[i]; + if (options_.has_min_score_threshold() && + score < options_.min_score_threshold()) + continue; + + Classification* classification = classification_list->add_classification(); + classification->set_score(score); + classification->set_index(index); + if (label_map_loaded_) classification->set_label(label_map_[index]); + } + + cc->Outputs() + .Tag(kClassificationListTag) + .Add(classification_list.release(), cc->InputTimestamp()); + + return ::mediapipe::OkStatus(); +} + +::mediapipe::Status PyTorchTensorsToClassificationCalculator::Close( + CalculatorContext* cc) { + return ::mediapipe::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/pytorch/pytorch_tensors_to_classification_calculator.proto b/mediapipe/calculators/pytorch/pytorch_tensors_to_classification_calculator.proto new file mode 100644 index 000000000..efac275c6 --- /dev/null +++ b/mediapipe/calculators/pytorch/pytorch_tensors_to_classification_calculator.proto @@ -0,0 +1,33 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message PyTorchTensorsToClassificationCalculatorOptions { + extend .mediapipe.CalculatorOptions { + optional PyTorchTensorsToClassificationCalculatorOptions ext = 266379463; + } + + // Score threshold for perserving the class. + optional float min_score_threshold = 1; + // Number of highest scoring labels to output. If top_k is not positive then + // all labels are used. + optional int32 top_k = 2; + // Path to a label map file for getting the actual name of class ids. + optional string label_map_path = 3; +} diff --git a/mediapipe/examples/desktop/object_classification/BUILD b/mediapipe/examples/desktop/object_classification/BUILD new file mode 100644 index 000000000..3571cfa8a --- /dev/null +++ b/mediapipe/examples/desktop/object_classification/BUILD @@ -0,0 +1,27 @@ +# Copyright 2020 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:defs.bzl", "cc_binary") + +licenses(["notice"]) + +package(default_visibility = ["//mediapipe/examples:__subpackages__"]) + +cc_binary( + name = "object_classification_pytorch_cpu", + deps = [ + "//mediapipe/examples/desktop:demo_run_graph_main", + "//mediapipe/graphs/object_classification:desktop_pytorch_calculators", + ], +) diff --git a/mediapipe/graphs/object_classification/BUILD b/mediapipe/graphs/object_classification/BUILD new file mode 100644 index 000000000..dc1552659 --- /dev/null +++ b/mediapipe/graphs/object_classification/BUILD @@ -0,0 +1,32 @@ +# Copyright 2020 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:defs.bzl", "cc_library") + +licenses(["notice"]) + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "desktop_pytorch_calculators", + deps = [ + "//mediapipe/calculators/core:flow_limiter_calculator", + "//mediapipe/calculators/image:image_transformation_calculator", + "//mediapipe/calculators/pytorch:pytorch_converter_calculator", + "//mediapipe/calculators/pytorch:pytorch_inference_calculator", + "//mediapipe/calculators/pytorch:pytorch_tensors_to_classification_calculator", + "//mediapipe/calculators/util:annotation_overlay_calculator", + "//mediapipe/calculators/util:labels_to_render_data_calculator", + ], +) diff --git a/mediapipe/graphs/object_classification/object_classification_desktop_live.pbtxt b/mediapipe/graphs/object_classification/object_classification_desktop_live.pbtxt new file mode 100644 index 000000000..f4b1524e7 --- /dev/null +++ b/mediapipe/graphs/object_classification/object_classification_desktop_live.pbtxt @@ -0,0 +1,120 @@ +# MediaPipe graph that performs object classification with PyTorch on CPU. +# Used in the examples in +# mediapipe/examples/desktop/object_classification:object_classification_pytorch_cpu. + +# Images on CPU coming into and out of the graph. +input_stream: "input_video" +output_stream: "output_video" + +# Throttles the images flowing downstream for flow control. It passes through +# the very first incoming image unaltered, and waits for +# PyTorchTensorsToClassificationCalculator downstream in the graph to finish +# generating the corresponding classifications before it passes through another +# image. All images that come in while waiting are dropped, limiting the number +# of in-flight images between this calculator and +# PyTorchTensorsToClassificationCalculator to 1. This prevents the nodes in between +# from queuing up incoming images and data excessively, which leads to increased +# latency and memory usage, unwanted in real-time mobile applications. It also +# eliminates unnecessarily computation, e.g., a transformed image produced by +# ImageTransformationCalculator may get dropped downstream if the subsequent +# PyTorchConverterCalculator or PyTorchInferenceCalculator is still busy +# processing previous inputs. +node { + calculator: "FlowLimiterCalculator" + input_stream: "input_video" + input_stream: "FINISHED:output_classifications" + input_stream_info: { + tag_index: "FINISHED" + back_edge: true + } + output_stream: "throttled_input_video" +} + +# Transforms the input image on CPU to a 224x224 image. To scale the input +# image, the scale_mode option is set to FIT to preserve the aspect ratio, +# resulting in potential letterboxing in the transformed image. +node: { + calculator: "ImageTransformationCalculator" + input_stream: "IMAGE:throttled_input_video" + output_stream: "IMAGE:transformed_input_video" + output_stream: "LETTERBOX_PADDING:letterbox_padding" + node_options: { + [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] { + output_width: 224 + output_height: 224 + scale_mode: FIT + } + } +} + +# Converts the transformed input image on CPU into an image tensor stored as a +# PyTorch tensor. Pixel values are normalized using mean = [0.485, 0.456, 0.406] +# and std = [0.229, 0.224, 0.225]. +node { + calculator: "PyTorchConverterCalculator" + input_stream: "IMAGE:transformed_input_video" + output_stream: "TENSORS:image_tensor" + node_options: { + [type.googleapis.com/mediapipe.PyTorchConverterCalculatorOptions] { + per_channel_normalizations: {sub:0.485 div:0.229} + per_channel_normalizations: {sub:0.456 div:0.224} + per_channel_normalizations: {sub:0.406 div:0.225} + } + } +} + +# Runs a PyTorch model on CPU that takes an image tensor and outputs a +# vector of tensors representing, for instance, classification scores. +node { + calculator: "PyTorchInferenceCalculator" + input_stream: "TENSORS:image_tensor" + output_stream: "TENSORS:classification_tensors" + node_options: { + [type.googleapis.com/mediapipe.PyTorchInferenceCalculatorOptions] { + model_path: "mediapipe/models/mobilenetv2.pt" + } + } +} + +# Decodes the classifications tensors generated by the PyTorch model, based on +# the specification in the options, into a vector of classifications. +# Maps classification label IDs to the corresponding label text. The label map is +# provided in the label_map_path option. +node { + calculator: "PyTorchTensorsToClassificationCalculator" + input_stream: "TENSORS:classification_tensors" + output_stream: "CLASSIFICATION_LIST:output_classifications" + node_options: { + [type.googleapis.com/mediapipe.PyTorchTensorsToClassificationCalculatorOptions] { + top_k: 3 + min_score_threshold: 0.1 + label_map_path: "mediapipe/models/mobilenetv2.labelmap" + } + } +} + +# Converts the classifications label to drawing primitives for annotation overlay. +node { + calculator: "LabelsToRenderDataCalculator" + input_stream: "CLASSIFICATIONS:output_classifications" + output_stream: "RENDER_DATA:render_data" + node_options: { + [type.googleapis.com/mediapipe.LabelsToRenderDataCalculatorOptions] { + color { r: 255 g: 0 b: 0 } + color { r: 0 g: 255 b: 0 } + color { r: 0 g: 0 b: 255 } + thickness: 2.0 + font_height_px: 20 + font_face: 1 + location: TOP_LEFT + } + } +} + +# Draws annotations and overlays them on top of the input images. +node { + calculator: "AnnotationOverlayCalculator" + input_stream: "IMAGE:throttled_input_video" + input_stream: "render_data" + output_stream: "IMAGE:output_video" +} diff --git a/mediapipe/models/.gitignore b/mediapipe/models/.gitignore new file mode 100644 index 000000000..5f849ac87 --- /dev/null +++ b/mediapipe/models/.gitignore @@ -0,0 +1 @@ +/mobilenetv2.pt diff --git a/mediapipe/models/mobilenetv2.labelmap b/mediapipe/models/mobilenetv2.labelmap new file mode 100644 index 000000000..a509c0074 --- /dev/null +++ b/mediapipe/models/mobilenetv2.labelmap @@ -0,0 +1,1000 @@ +tench, Tinca tinca +goldfish, Carassius auratus +great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias +tiger shark, Galeocerdo cuvieri +hammerhead, hammerhead shark +electric ray, crampfish, numbfish, torpedo +stingray +cock +hen +ostrich, Struthio camelus +brambling, Fringilla montifringilla +goldfinch, Carduelis carduelis +house finch, linnet, Carpodacus mexicanus +junco, snowbird +indigo bunting, indigo finch, indigo bird, Passerina cyanea +robin, American robin, Turdus migratorius +bulbul +jay +magpie +chickadee +water ouzel, dipper +kite +bald eagle, American eagle, Haliaeetus leucocephalus +vulture +great grey owl, great gray owl, Strix nebulosa +European fire salamander, Salamandra salamandra +common newt, Triturus vulgaris +eft +spotted salamander, Ambystoma maculatum +axolotl, mud puppy, Ambystoma mexicanum +bullfrog, Rana catesbeiana +tree frog, tree-frog +tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui +loggerhead, loggerhead turtle, Caretta caretta +leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea +mud turtle +terrapin +box turtle, box tortoise +banded gecko +common iguana, iguana, Iguana iguana +American chameleon, anole, Anolis carolinensis +whiptail, whiptail lizard +agama +frilled lizard, Chlamydosaurus kingi +alligator lizard +Gila monster, Heloderma suspectum +green lizard, Lacerta viridis +African chameleon, Chamaeleo chamaeleon +Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis +African crocodile, Nile crocodile, Crocodylus niloticus +American alligator, Alligator mississipiensis +triceratops +thunder snake, worm snake, Carphophis amoenus +ringneck snake, ring-necked snake, ring snake +hognose snake, puff adder, sand viper +green snake, grass snake +king snake, kingsnake +garter snake, grass snake +water snake +vine snake +night snake, Hypsiglena torquata +boa constrictor, Constrictor constrictor +rock python, rock snake, Python sebae +Indian cobra, Naja naja +green mamba +sea snake +horned viper, cerastes, sand viper, horned asp, Cerastes cornutus +diamondback, diamondback rattlesnake, Crotalus adamanteus +sidewinder, horned rattlesnake, Crotalus cerastes +trilobite +harvestman, daddy longlegs, Phalangium opilio +scorpion +black and gold garden spider, Argiope aurantia +barn spider, Araneus cavaticus +garden spider, Aranea diademata +black widow, Latrodectus mactans +tarantula +wolf spider, hunting spider +tick +centipede +black grouse +ptarmigan +ruffed grouse, partridge, Bonasa umbellus +prairie chicken, prairie grouse, prairie fowl +peacock +quail +partridge +African grey, African gray, Psittacus erithacus +macaw +sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita +lorikeet +coucal +bee eater +hornbill +hummingbird +jacamar +toucan +drake +red-breasted merganser, Mergus serrator +goose +black swan, Cygnus atratus +tusker +echidna, spiny anteater, anteater +platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus +wallaby, brush kangaroo +koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus +wombat +jellyfish +sea anemone, anemone +brain coral +flatworm, platyhelminth +nematode, nematode worm, roundworm +conch +snail +slug +sea slug, nudibranch +chiton, coat-of-mail shell, sea cradle, polyplacophore +chambered nautilus, pearly nautilus, nautilus +Dungeness crab, Cancer magister +rock crab, Cancer irroratus +fiddler crab +king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica +American lobster, Northern lobster, Maine lobster, Homarus americanus +spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish +crayfish, crawfish, crawdad, crawdaddy +hermit crab +isopod +white stork, Ciconia ciconia +black stork, Ciconia nigra +spoonbill +flamingo +little blue heron, Egretta caerulea +American egret, great white heron, Egretta albus +bittern +crane +limpkin, Aramus pictus +European gallinule, Porphyrio porphyrio +American coot, marsh hen, mud hen, water hen, Fulica americana +bustard +ruddy turnstone, Arenaria interpres +red-backed sandpiper, dunlin, Erolia alpina +redshank, Tringa totanus +dowitcher +oystercatcher, oyster catcher +pelican +king penguin, Aptenodytes patagonica +albatross, mollymawk +grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus +killer whale, killer, orca, grampus, sea wolf, Orcinus orca +dugong, Dugong dugon +sea lion +Chihuahua +Japanese spaniel +Maltese dog, Maltese terrier, Maltese +Pekinese, Pekingese, Peke +Shih-Tzu +Blenheim spaniel +papillon +toy terrier +Rhodesian ridgeback +Afghan hound, Afghan +basset, basset hound +beagle +bloodhound, sleuthhound +bluetick +black-and-tan coonhound +Walker hound, Walker foxhound +English foxhound +redbone +borzoi, Russian wolfhound +Irish wolfhound +Italian greyhound +whippet +Ibizan hound, Ibizan Podenco +Norwegian elkhound, elkhound +otterhound, otter hound +Saluki, gazelle hound +Scottish deerhound, deerhound +Weimaraner +Staffordshire bullterrier, Staffordshire bull terrier +American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier +Bedlington terrier +Border terrier +Kerry blue terrier +Irish terrier +Norfolk terrier +Norwich terrier +Yorkshire terrier +wire-haired fox terrier +Lakeland terrier +Sealyham terrier, Sealyham +Airedale, Airedale terrier +cairn, cairn terrier +Australian terrier +Dandie Dinmont, Dandie Dinmont terrier +Boston bull, Boston terrier +miniature schnauzer +giant schnauzer +standard schnauzer +Scotch terrier, Scottish terrier, Scottie +Tibetan terrier, chrysanthemum dog +silky terrier, Sydney silky +soft-coated wheaten terrier +West Highland white terrier +Lhasa, Lhasa apso +flat-coated retriever +curly-coated retriever +golden retriever +Labrador retriever +Chesapeake Bay retriever +German short-haired pointer +vizsla, Hungarian pointer +English setter +Irish setter, red setter +Gordon setter +Brittany spaniel +clumber, clumber spaniel +English springer, English springer spaniel +Welsh springer spaniel +cocker spaniel, English cocker spaniel, cocker +Sussex spaniel +Irish water spaniel +kuvasz +schipperke +groenendael +malinois +briard +kelpie +komondor +Old English sheepdog, bobtail +Shetland sheepdog, Shetland sheep dog, Shetland +collie +Border collie +Bouvier des Flandres, Bouviers des Flandres +Rottweiler +German shepherd, German shepherd dog, German police dog, alsatian +Doberman, Doberman pinscher +miniature pinscher +Greater Swiss Mountain dog +Bernese mountain dog +Appenzeller +EntleBucher +boxer +bull mastiff +Tibetan mastiff +French bulldog +Great Dane +Saint Bernard, St Bernard +Eskimo dog, husky +malamute, malemute, Alaskan malamute +Siberian husky +dalmatian, coach dog, carriage dog +affenpinscher, monkey pinscher, monkey dog +basenji +pug, pug-dog +Leonberg +Newfoundland, Newfoundland dog +Great Pyrenees +Samoyed, Samoyede +Pomeranian +chow, chow chow +keeshond +Brabancon griffon +Pembroke, Pembroke Welsh corgi +Cardigan, Cardigan Welsh corgi +toy poodle +miniature poodle +standard poodle +Mexican hairless +timber wolf, grey wolf, gray wolf, Canis lupus +white wolf, Arctic wolf, Canis lupus tundrarum +red wolf, maned wolf, Canis rufus, Canis niger +coyote, prairie wolf, brush wolf, Canis latrans +dingo, warrigal, warragal, Canis dingo +dhole, Cuon alpinus +African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus +hyena, hyaena +red fox, Vulpes vulpes +kit fox, Vulpes macrotis +Arctic fox, white fox, Alopex lagopus +grey fox, gray fox, Urocyon cinereoargenteus +tabby, tabby cat +tiger cat +Persian cat +Siamese cat, Siamese +Egyptian cat +cougar, puma, catamount, mountain lion, painter, panther, Felis concolor +lynx, catamount +leopard, Panthera pardus +snow leopard, ounce, Panthera uncia +jaguar, panther, Panthera onca, Felis onca +lion, king of beasts, Panthera leo +tiger, Panthera tigris +cheetah, chetah, Acinonyx jubatus +brown bear, bruin, Ursus arctos +American black bear, black bear, Ursus americanus, Euarctos americanus +ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus +sloth bear, Melursus ursinus, Ursus ursinus +mongoose +meerkat, mierkat +tiger beetle +ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle +ground beetle, carabid beetle +long-horned beetle, longicorn, longicorn beetle +leaf beetle, chrysomelid +dung beetle +rhinoceros beetle +weevil +fly +bee +ant, emmet, pismire +grasshopper, hopper +cricket +walking stick, walkingstick, stick insect +cockroach, roach +mantis, mantid +cicada, cicala +leafhopper +lacewing, lacewing fly +dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk +damselfly +admiral +ringlet, ringlet butterfly +monarch, monarch butterfly, milkweed butterfly, Danaus plexippus +cabbage butterfly +sulphur butterfly, sulfur butterfly +lycaenid, lycaenid butterfly +starfish, sea star +sea urchin +sea cucumber, holothurian +wood rabbit, cottontail, cottontail rabbit +hare +Angora, Angora rabbit +hamster +porcupine, hedgehog +fox squirrel, eastern fox squirrel, Sciurus niger +marmot +beaver +guinea pig, Cavia cobaya +sorrel +zebra +hog, pig, grunter, squealer, Sus scrofa +wild boar, boar, Sus scrofa +warthog +hippopotamus, hippo, river horse, Hippopotamus amphibius +ox +water buffalo, water ox, Asiatic buffalo, Bubalus bubalis +bison +ram, tup +bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis +ibex, Capra ibex +hartebeest +impala, Aepyceros melampus +gazelle +Arabian camel, dromedary, Camelus dromedarius +llama +weasel +mink +polecat, fitch, foulmart, foumart, Mustela putorius +black-footed ferret, ferret, Mustela nigripes +otter +skunk, polecat, wood pussy +badger +armadillo +three-toed sloth, ai, Bradypus tridactylus +orangutan, orang, orangutang, Pongo pygmaeus +gorilla, Gorilla gorilla +chimpanzee, chimp, Pan troglodytes +gibbon, Hylobates lar +siamang, Hylobates syndactylus, Symphalangus syndactylus +guenon, guenon monkey +patas, hussar monkey, Erythrocebus patas +baboon +macaque +langur +colobus, colobus monkey +proboscis monkey, Nasalis larvatus +marmoset +capuchin, ringtail, Cebus capucinus +howler monkey, howler +titi, titi monkey +spider monkey, Ateles geoffroyi +squirrel monkey, Saimiri sciureus +Madagascar cat, ring-tailed lemur, Lemur catta +indri, indris, Indri indri, Indri brevicaudatus +Indian elephant, Elephas maximus +African elephant, Loxodonta africana +lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens +giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca +barracouta, snoek +eel +coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch +rock beauty, Holocanthus tricolor +anemone fish +sturgeon +gar, garfish, garpike, billfish, Lepisosteus osseus +lionfish +puffer, pufferfish, blowfish, globefish +abacus +abaya +academic gown, academic robe, judge's robe +accordion, piano accordion, squeeze box +acoustic guitar +aircraft carrier, carrier, flattop, attack aircraft carrier +airliner +airship, dirigible +altar +ambulance +amphibian, amphibious vehicle +analog clock +apiary, bee house +apron +ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin +assault rifle, assault gun +backpack, back pack, knapsack, packsack, rucksack, haversack +bakery, bakeshop, bakehouse +balance beam, beam +balloon +ballpoint, ballpoint pen, ballpen, Biro +Band Aid +banjo +bannister, banister, balustrade, balusters, handrail +barbell +barber chair +barbershop +barn +barometer +barrel, cask +barrow, garden cart, lawn cart, wheelbarrow +baseball +basketball +bassinet +bassoon +bathing cap, swimming cap +bath towel +bathtub, bathing tub, bath, tub +beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon +beacon, lighthouse, beacon light, pharos +beaker +bearskin, busby, shako +beer bottle +beer glass +bell cote, bell cot +bib +bicycle-built-for-two, tandem bicycle, tandem +bikini, two-piece +binder, ring-binder +binoculars, field glasses, opera glasses +birdhouse +boathouse +bobsled, bobsleigh, bob +bolo tie, bolo, bola tie, bola +bonnet, poke bonnet +bookcase +bookshop, bookstore, bookstall +bottlecap +bow +bow tie, bow-tie, bowtie +brass, memorial tablet, plaque +brassiere, bra, bandeau +breakwater, groin, groyne, mole, bulwark, seawall, jetty +breastplate, aegis, egis +broom +bucket, pail +buckle +bulletproof vest +bullet train, bullet +butcher shop, meat market +cab, hack, taxi, taxicab +caldron, cauldron +candle, taper, wax light +cannon +canoe +can opener, tin opener +cardigan +car mirror +carousel, carrousel, merry-go-round, roundabout, whirligig +carpenter's kit, tool kit +carton +car wheel +cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM +cassette +cassette player +castle +catamaran +CD player +cello, violoncello +cellular telephone, cellular phone, cellphone, cell, mobile phone +chain +chainlink fence +chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour +chain saw, chainsaw +chest +chiffonier, commode +chime, bell, gong +china cabinet, china closet +Christmas stocking +church, church building +cinema, movie theater, movie theatre, movie house, picture palace +cleaver, meat cleaver, chopper +cliff dwelling +cloak +clog, geta, patten, sabot +cocktail shaker +coffee mug +coffeepot +coil, spiral, volute, whorl, helix +combination lock +computer keyboard, keypad +confectionery, confectionary, candy store +container ship, containership, container vessel +convertible +corkscrew, bottle screw +cornet, horn, trumpet, trump +cowboy boot +cowboy hat, ten-gallon hat +cradle +crane +crash helmet +crate +crib, cot +Crock Pot +croquet ball +crutch +cuirass +dam, dike, dyke +desk +desktop computer +dial telephone, dial phone +diaper, nappy, napkin +digital clock +digital watch +dining table, board +dishrag, dishcloth +dishwasher, dish washer, dishwashing machine +disk brake, disc brake +dock, dockage, docking facility +dogsled, dog sled, dog sleigh +dome +doormat, welcome mat +drilling platform, offshore rig +drum, membranophone, tympan +drumstick +dumbbell +Dutch oven +electric fan, blower +electric guitar +electric locomotive +entertainment center +envelope +espresso maker +face powder +feather boa, boa +file, file cabinet, filing cabinet +fireboat +fire engine, fire truck +fire screen, fireguard +flagpole, flagstaff +flute, transverse flute +folding chair +football helmet +forklift +fountain +fountain pen +four-poster +freight car +French horn, horn +frying pan, frypan, skillet +fur coat +garbage truck, dustcart +gasmask, respirator, gas helmet +gas pump, gasoline pump, petrol pump, island dispenser +goblet +go-kart +golf ball +golfcart, golf cart +gondola +gong, tam-tam +gown +grand piano, grand +greenhouse, nursery, glasshouse +grille, radiator grille +grocery store, grocery, food market, market +guillotine +hair slide +hair spray +half track +hammer +hamper +hand blower, blow dryer, blow drier, hair dryer, hair drier +hand-held computer, hand-held microcomputer +handkerchief, hankie, hanky, hankey +hard disc, hard disk, fixed disk +harmonica, mouth organ, harp, mouth harp +harp +harvester, reaper +hatchet +holster +home theater, home theatre +honeycomb +hook, claw +hoopskirt, crinoline +horizontal bar, high bar +horse cart, horse-cart +hourglass +iPod +iron, smoothing iron +jack-o'-lantern +jean, blue jean, denim +jeep, landrover +jersey, T-shirt, tee shirt +jigsaw puzzle +jinrikisha, ricksha, rickshaw +joystick +kimono +knee pad +knot +lab coat, laboratory coat +ladle +lampshade, lamp shade +laptop, laptop computer +lawn mower, mower +lens cap, lens cover +letter opener, paper knife, paperknife +library +lifeboat +lighter, light, igniter, ignitor +limousine, limo +liner, ocean liner +lipstick, lip rouge +Loafer +lotion +loudspeaker, speaker, speaker unit, loudspeaker system, speaker system +loupe, jeweler's loupe +lumbermill, sawmill +magnetic compass +mailbag, postbag +mailbox, letter box +maillot +maillot, tank suit +manhole cover +maraca +marimba, xylophone +mask +matchstick +maypole +maze, labyrinth +measuring cup +medicine chest, medicine cabinet +megalith, megalithic structure +microphone, mike +microwave, microwave oven +military uniform +milk can +minibus +miniskirt, mini +minivan +missile +mitten +mixing bowl +mobile home, manufactured home +Model T +modem +monastery +monitor +moped +mortar +mortarboard +mosque +mosquito net +motor scooter, scooter +mountain bike, all-terrain bike, off-roader +mountain tent +mouse, computer mouse +mousetrap +moving van +muzzle +nail +neck brace +necklace +nipple +notebook, notebook computer +obelisk +oboe, hautboy, hautbois +ocarina, sweet potato +odometer, hodometer, mileometer, milometer +oil filter +organ, pipe organ +oscilloscope, scope, cathode-ray oscilloscope, CRO +overskirt +oxcart +oxygen mask +packet +paddle, boat paddle +paddlewheel, paddle wheel +padlock +paintbrush +pajama, pyjama, pj's, jammies +palace +panpipe, pandean pipe, syrinx +paper towel +parachute, chute +parallel bars, bars +park bench +parking meter +passenger car, coach, carriage +patio, terrace +pay-phone, pay-station +pedestal, plinth, footstall +pencil box, pencil case +pencil sharpener +perfume, essence +Petri dish +photocopier +pick, plectrum, plectron +pickelhaube +picket fence, paling +pickup, pickup truck +pier +piggy bank, penny bank +pill bottle +pillow +ping-pong ball +pinwheel +pirate, pirate ship +pitcher, ewer +plane, carpenter's plane, woodworking plane +planetarium +plastic bag +plate rack +plow, plough +plunger, plumber's helper +Polaroid camera, Polaroid Land camera +pole +police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria +poncho +pool table, billiard table, snooker table +pop bottle, soda bottle +pot, flowerpot +potter's wheel +power drill +prayer rug, prayer mat +printer +prison, prison house +projectile, missile +projector +puck, hockey puck +punching bag, punch bag, punching ball, punchball +purse +quill, quill pen +quilt, comforter, comfort, puff +racer, race car, racing car +racket, racquet +radiator +radio, wireless +radio telescope, radio reflector +rain barrel +recreational vehicle, RV, R.V. +reel +reflex camera +refrigerator, icebox +remote control, remote +restaurant, eating house, eating place, eatery +revolver, six-gun, six-shooter +rifle +rocking chair, rocker +rotisserie +rubber eraser, rubber, pencil eraser +rugby ball +rule, ruler +running shoe +safe +safety pin +saltshaker, salt shaker +sandal +sarong +sax, saxophone +scabbard +scale, weighing machine +school bus +schooner +scoreboard +screen, CRT screen +screw +screwdriver +seat belt, seatbelt +sewing machine +shield, buckler +shoe shop, shoe-shop, shoe store +shoji +shopping basket +shopping cart +shovel +shower cap +shower curtain +ski +ski mask +sleeping bag +slide rule, slipstick +sliding door +slot, one-armed bandit +snorkel +snowmobile +snowplow, snowplough +soap dispenser +soccer ball +sock +solar dish, solar collector, solar furnace +sombrero +soup bowl +space bar +space heater +space shuttle +spatula +speedboat +spider web, spider's web +spindle +sports car, sport car +spotlight, spot +stage +steam locomotive +steel arch bridge +steel drum +stethoscope +stole +stone wall +stopwatch, stop watch +stove +strainer +streetcar, tram, tramcar, trolley, trolley car +stretcher +studio couch, day bed +stupa, tope +submarine, pigboat, sub, U-boat +suit, suit of clothes +sundial +sunglass +sunglasses, dark glasses, shades +sunscreen, sunblock, sun blocker +suspension bridge +swab, swob, mop +sweatshirt +swimming trunks, bathing trunks +swing +switch, electric switch, electrical switch +syringe +table lamp +tank, army tank, armored combat vehicle, armoured combat vehicle +tape player +teapot +teddy, teddy bear +television, television system +tennis ball +thatch, thatched roof +theater curtain, theatre curtain +thimble +thresher, thrasher, threshing machine +throne +tile roof +toaster +tobacco shop, tobacconist shop, tobacconist +toilet seat +torch +totem pole +tow truck, tow car, wrecker +toyshop +tractor +trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi +tray +trench coat +tricycle, trike, velocipede +trimaran +tripod +triumphal arch +trolleybus, trolley coach, trackless trolley +trombone +tub, vat +turnstile +typewriter keyboard +umbrella +unicycle, monocycle +upright, upright piano +vacuum, vacuum cleaner +vase +vault +velvet +vending machine +vestment +viaduct +violin, fiddle +volleyball +waffle iron +wall clock +wallet, billfold, notecase, pocketbook +wardrobe, closet, press +warplane, military plane +washbasin, handbasin, washbowl, lavabo, wash-hand basin +washer, automatic washer, washing machine +water bottle +water jug +water tower +whiskey jug +whistle +wig +window screen +window shade +Windsor tie +wine bottle +wing +wok +wooden spoon +wool, woolen, woollen +worm fence, snake fence, snake-rail fence, Virginia fence +wreck +yawl +yurt +web site, website, internet site, site +comic book +crossword puzzle, crossword +street sign +traffic light, traffic signal, stoplight +book jacket, dust cover, dust jacket, dust wrapper +menu +plate +guacamole +consomme +hot pot, hotpot +trifle +ice cream, icecream +ice lolly, lolly, lollipop, popsicle +French loaf +bagel, beigel +pretzel +cheeseburger +hotdog, hot dog, red hot +mashed potato +head cabbage +broccoli +cauliflower +zucchini, courgette +spaghetti squash +acorn squash +butternut squash +cucumber, cuke +artichoke, globe artichoke +bell pepper +cardoon +mushroom +Granny Smith +strawberry +orange +lemon +fig +pineapple, ananas +banana +jackfruit, jak, jack +custard apple +pomegranate +hay +carbonara +chocolate sauce, chocolate syrup +dough +meat loaf, meatloaf +pizza, pizza pie +potpie +burrito +red wine +espresso +cup +eggnog +alp +bubble +cliff, drop, drop-off +coral reef +geyser +lakeside, lakeshore +promontory, headland, head, foreland +sandbar, sand bar +seashore, coast, seacoast, sea-coast +valley, vale +volcano +ballplayer, baseball player +groom, bridegroom +scuba diver +rapeseed +daisy +yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum +corn +acorn +hip, rose hip, rosehip +buckeye, horse chestnut, conker +coral fungus +agaric +gyromitra +stinkhorn, carrion fungus +earthstar +hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa +bolete +ear, spike, capitulum +toilet tissue, toilet paper, bathroom tissue diff --git a/mediapipe/models/trace_mobilenetv2.py b/mediapipe/models/trace_mobilenetv2.py new file mode 100644 index 000000000..ff993287d --- /dev/null +++ b/mediapipe/models/trace_mobilenetv2.py @@ -0,0 +1,14 @@ +# `torch` package known to work: +# python -m pip install torch==1.4.0+cpu torchvision==0.5.0+cpu -f https://download.pytorch.org/whl/torch_stable.html +import torch + +# More about this model at https://pytorch.org/hub/pytorch_vision_mobilenet_v2 + +if __name__ == '__main__': + model = torch.hub.load('pytorch/vision:v0.5.0', + 'mobilenet_v2', + pretrained=True) + model.eval() + input_tensor = torch.rand(1, 3, 224, 224) + script_model = torch.jit.trace(model, input_tensor) + script_model.save("mediapipe/models/mobilenetv2.pt") diff --git a/third_party/BUILD b/third_party/BUILD index 4d2676751..89b28818a 100644 --- a/third_party/BUILD +++ b/third_party/BUILD @@ -209,6 +209,21 @@ cc_library( }), ) +cc_library( + name = "libtorch", + visibility = ["//visibility:public"], + deps = select({ + "//mediapipe:android_x86": [], + "//mediapipe:android_x86_64": [], + "//mediapipe:android_armeabi": [], + "//mediapipe:android_arm": [], + "//mediapipe:android_arm64": [], + "//mediapipe:ios": ["@ios_libtorch//:libtorch"], + "//mediapipe:macos": ["@macos_libtorch_cpu//:libtorch_cpu"], + "//conditions:default": ["@linux_libtorch_cpu//:libtorch_cpu"], + }), +) + android_library( name = "androidx_annotation", exports = [ diff --git a/third_party/libtorch_linux.BUILD b/third_party/libtorch_linux.BUILD new file mode 100644 index 000000000..a78651592 --- /dev/null +++ b/third_party/libtorch_linux.BUILD @@ -0,0 +1,32 @@ +# Copyright 2020 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:defs.bzl", "cc_library") + +cc_library( + name = "libtorch_cpu", + srcs = glob([ + "lib/libc10.so", + "lib/libgomp-*.so.1", + "lib/libtorch.so", + ]), + hdrs = glob(["include/**/*.h"]), + includes = [ + "include", + "include/TH", + "include/THC", + "include/torch/csrc/api/include", + ], + visibility = ["//visibility:public"], +)