diff --git a/WORKSPACE b/WORKSPACE
index eb3efd275..1bd865c1e 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -361,6 +361,19 @@ http_archive(
],
)
+# PyTorch archives
+
+# 2020-03-12
+# On https://pytorch.org: Stable > Linux > LibTorch > C++ > (CUDA) None > cxx11 ABI
+http_archive(
+ name = "linux_libtorch_cpu",
+ build_file = "@//third_party:libtorch_linux.BUILD",
+ sha256 = "33a9dd142d0497375db42b055bd90780f9d92047a19edc8891e6232e2b5bdba7",
+ strip_prefix = "libtorch",
+ type = "zip",
+ url = "https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-1.4.0%2Bcpu.zip",
+)
+
#Tensorflow repo should always go after the other external dependencies.
# 2020-08-30
_TENSORFLOW_GIT_COMMIT = "57b009e31e59bd1a7ae85ef8c0232ed86c9b71db"
diff --git a/docs/solutions/object_classification.md b/docs/solutions/object_classification.md
new file mode 100644
index 000000000..e5c69a29d
--- /dev/null
+++ b/docs/solutions/object_classification.md
@@ -0,0 +1,79 @@
+---
+layout: default
+title: Object Classification
+parent: Solutions
+nav_order: TODO
+---
+
+# MediaPipe Object Classification
+{: .no_toc }
+
+1. TOC
+{:toc}
+---
+
+## Example Apps
+
+Note: To visualize a graph, copy the graph and paste it into
+[MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how
+to visualize its associated subgraphs, please see
+[visualizer documentation](../tools/visualizer.md).
+
+
+
+### Desktop
+
+#### Live Camera Input
+
+Please first see general instructions for
+[desktop](../getting_started/building_examples.md#desktop) on how to build MediaPipe examples.
+
+* Graph:
+ [`mediapipe/graphs/object_classification/object_classification_desktop_live.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_classification/object_classification_desktop_live.pbtxt)
+* Target:
+ [`mediapipe/examples/desktop/object_classification:object_classification_pytorch_cpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/object_classification/BUILD)
+
+#### Video File Input
+
+* With a PyTorch Model
+
+ This uses a MobileNetv2 trace model from PyTorch Hub. To fetch and prepare it, run:
+
+ ```bash
+ python mediapipe/models/trace_mobilenetv2.py
+ ```
+
+ The pipeline is implemented in this
+ [graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_classification/object_classification_desktop_live.pbtxt).
+
+ To build the application, run:
+
+ ```bash
+ bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 mediapipe/examples/desktop/object_classification:object_classification_pytorch_cpu
+ ```
+
+ To run the application, replace ` ` and `` in the command below with your own paths:
+
+ Tip: You can find a test video available in
+ `mediapipe/examples/desktop/object_detection`.
+
+ ```
+ GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/object_classification/object_classification_pytorch_cpu \
+ --calculator_graph_config_file=mediapipe/graphs/object_classification/object_classification_desktop_live.pbtxt \
+ --input_side_packets=input_video_path= ,output_video_path=
+ ```
+
\ No newline at end of file
diff --git a/mediapipe/calculators/pytorch/BUILD b/mediapipe/calculators/pytorch/BUILD
new file mode 100644
index 000000000..8880fca5d
--- /dev/null
+++ b/mediapipe/calculators/pytorch/BUILD
@@ -0,0 +1,138 @@
+# Copyright 2020 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("@rules_proto//proto:defs.bzl", "proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_library")
+load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
+
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//visibility:private"])
+
+proto_library(
+ name = "pytorch_converter_calculator_proto",
+ srcs = ["pytorch_converter_calculator.proto"],
+ visibility = ["//visibility:public"],
+ deps = ["//mediapipe/framework:calculator_proto"],
+)
+
+proto_library(
+ name = "pytorch_inference_calculator_proto",
+ srcs = ["pytorch_inference_calculator.proto"],
+ visibility = ["//visibility:public"],
+ deps = ["//mediapipe/framework:calculator_proto"],
+)
+
+proto_library(
+ name = "pytorch_tensors_to_classification_calculator_proto",
+ srcs = ["pytorch_tensors_to_classification_calculator.proto"],
+ visibility = ["//visibility:public"],
+ deps = ["//mediapipe/framework:calculator_proto"],
+)
+
+mediapipe_cc_proto_library(
+ name = "pytorch_converter_calculator_cc_proto",
+ srcs = ["pytorch_converter_calculator.proto"],
+ cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
+ visibility = ["//visibility:public"],
+ deps = [":pytorch_converter_calculator_proto"],
+)
+
+mediapipe_cc_proto_library(
+ name = "pytorch_inference_calculator_cc_proto",
+ srcs = ["pytorch_inference_calculator.proto"],
+ cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
+ visibility = ["//visibility:public"],
+ deps = [":pytorch_inference_calculator_proto"],
+)
+
+mediapipe_cc_proto_library(
+ name = "pytorch_tensors_to_classification_calculator_cc_proto",
+ srcs = ["pytorch_tensors_to_classification_calculator.proto"],
+ cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
+ visibility = ["//visibility:public"],
+ deps = [":pytorch_tensors_to_classification_calculator_proto"],
+)
+
+cc_library(
+ name = "pytorch_converter_calculator",
+ srcs = ["pytorch_converter_calculator.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":pytorch_converter_calculator_cc_proto",
+ "//mediapipe/framework:calculator_framework",
+ "//mediapipe/framework/formats:image_frame",
+ "//mediapipe/framework/formats:image_frame_opencv",
+ "//mediapipe/framework/port:opencv_imgproc",
+ "//mediapipe/framework/port:ret_check",
+ "//mediapipe/framework/port:status",
+ "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
+ "//third_party:libtorch",
+ ] + select({
+ "//mediapipe:ios": [
+ "//mediapipe/gpu:gpu_buffer",
+ "//mediapipe/objc:util",
+ ],
+ "//conditions:default": [],
+ }),
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "pytorch_inference_calculator",
+ srcs = ["pytorch_inference_calculator.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":pytorch_inference_calculator_cc_proto",
+ "//mediapipe/framework:calculator_framework",
+ "//mediapipe/framework/port:ret_check",
+ "//mediapipe/framework/port:status",
+ "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
+ "//mediapipe/util:resource_util",
+ "//third_party:libtorch",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "pytorch_tensors_to_classification_calculator",
+ srcs = ["pytorch_tensors_to_classification_calculator.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":pytorch_tensors_to_classification_calculator_cc_proto",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
+ "//mediapipe/framework/port:status",
+ "//mediapipe/framework/formats:classification_cc_proto",
+ "//mediapipe/framework:calculator_framework",
+ "//mediapipe/framework/formats:location",
+ "//mediapipe/framework/port:ret_check",
+ "//mediapipe/util:resource_util",
+ "//third_party:libtorch",
+ ] + select({
+ "//mediapipe:android": [
+ "//mediapipe/util/android/file/base",
+ ],
+ "//mediapipe:apple": [
+ "//mediapipe/util/android/file/base",
+ ],
+ "//mediapipe:macos": [
+ "//mediapipe/framework/port:file_helpers",
+ ],
+ "//conditions:default": [
+ "//mediapipe/framework/port:file_helpers",
+ ],
+ }),
+ alwayslink = 1,
+)
diff --git a/mediapipe/calculators/pytorch/pytorch_converter_calculator.cc b/mediapipe/calculators/pytorch/pytorch_converter_calculator.cc
new file mode 100644
index 000000000..b2696f73f
--- /dev/null
+++ b/mediapipe/calculators/pytorch/pytorch_converter_calculator.cc
@@ -0,0 +1,179 @@
+// Copyright 2020 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+
+#include "mediapipe/calculators/pytorch/pytorch_converter_calculator.pb.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/formats/image_frame.h"
+#include "mediapipe/framework/formats/image_frame_opencv.h"
+#include "mediapipe/framework/port/opencv_imgproc_inc.h"
+#include "mediapipe/framework/port/ret_check.h"
+#include "mediapipe/framework/port/status.h"
+#include "torch/script.h"
+#include "torch/torch.h"
+
+#if defined(MEDIAPIPE_IOS)
+#include "mediapipe/gpu/gpu_buffer.h"
+#include "mediapipe/objc/util.h"
+#endif // iOS
+
+namespace mediapipe {
+
+namespace {
+constexpr char kImageTag[] = "IMAGE";
+constexpr char kImageGpuTag[] = "IMAGE_GPU";
+constexpr char kTensorsTag[] = "TENSORS";
+
+using Outputs = std::vector;
+} // namespace
+
+// Calculator for normalizing and converting an ImageFrame
+// into a PyTorchTensor.
+//
+// This calculator is designed to be used with the PyTorchInferenceCalculator,
+// as a pre-processing step for calculator inputs.
+//
+// IMAGE inputs are normalized to [-1,1] (default) or [0,1],
+// specified by options (unless outputting a quantized tensor).
+//
+// Input:
+// One of the following tags:
+// IMAGE - ImageFrame (assumed to be 8-bit or 32-bit data).
+//
+// Output:
+// One of the following tags:
+// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32, or kTfLiteUint8.
+//
+// Example use:
+// node {
+// calculator: "PyTorchConverterCalculator"
+// input_stream: "IMAGE:input_image"
+// output_stream: "TENSORS:image_tensor"
+// options: {
+// [mediapipe.PyTorchConverterCalculatorOptions.ext] {
+// zero_center: true
+// }
+// }
+// }
+//
+// IMPORTANT Notes:
+// This calculator uses FixedSizeInputStreamHandler by default.
+//
+class PyTorchConverterCalculator : public CalculatorBase {
+ public:
+ static ::mediapipe::Status GetContract(CalculatorContract* cc);
+
+ ::mediapipe::Status Open(CalculatorContext* cc) override;
+ ::mediapipe::Status Process(CalculatorContext* cc) override;
+ ::mediapipe::Status Close(CalculatorContext* cc) override;
+
+ private:
+ ::mediapipe::PyTorchConverterCalculatorOptions options_;
+ bool has_image_tag_;
+ bool has_image_gpu_tag_;
+};
+REGISTER_CALCULATOR(PyTorchConverterCalculator);
+
+::mediapipe::Status PyTorchConverterCalculator::GetContract(
+ CalculatorContract* cc) {
+ const bool has_image_tag = cc->Inputs().HasTag(kImageTag);
+ const bool has_image_gpu_tag = cc->Inputs().HasTag(kImageGpuTag);
+ // Confirm only one of the input streams is present.
+ RET_CHECK(has_image_tag ^ has_image_gpu_tag);
+
+ const bool has_tensors_tag = cc->Outputs().HasTag(kTensorsTag);
+ RET_CHECK(has_tensors_tag);
+
+ if (has_image_tag) cc->Inputs().Tag(kImageTag).Set();
+ if (has_image_gpu_tag) {
+#if defined(MEDIAPIPE_IOS)
+ cc->Inputs().Tag(kImageGpuTag).Set();
+#else
+ RET_CHECK_FAIL() << "GPU processing not enabled.";
+#endif
+ }
+
+ if (has_tensors_tag) cc->Outputs().Tag(kTensorsTag).Set();
+
+ // Assign this calculator's default InputStreamHandler.
+ cc->SetInputStreamHandler("FixedSizeInputStreamHandler");
+
+ return ::mediapipe::OkStatus();
+}
+
+::mediapipe::Status PyTorchConverterCalculator::Open(CalculatorContext* cc) {
+ cc->SetOffset(TimestampDiff(0));
+
+ options_ = cc->Options<::mediapipe::PyTorchConverterCalculatorOptions>();
+
+ has_image_tag_ = cc->Inputs().HasTag(kImageTag);
+ has_image_gpu_tag_ = cc->Inputs().HasTag(kImageGpuTag);
+
+ if (has_image_gpu_tag_) {
+#if !defined(MEDIAPIPE_IOS)
+ RET_CHECK_FAIL() << "GPU processing not enabled.";
+#endif
+ }
+
+ return ::mediapipe::OkStatus();
+}
+
+::mediapipe::Status PyTorchConverterCalculator::Process(CalculatorContext* cc) {
+ cv::Mat image;
+ if (has_image_gpu_tag_) {
+#if defined(MEDIAPIPE_IOS) && MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
+ const auto& input = cc->Inputs().Tag(kImageGpuTag).Get();
+ std::unique_ptr frame =
+ CreateImageFrameForCVPixelBuffer(input.GetCVPixelBufferRef());
+ image = mediapipe::formats::MatView(frame.release());
+#else
+ RET_CHECK_FAIL() << "GPU processing is not enabled.";
+#endif
+ }
+ if (has_image_tag_) {
+ auto& output_frame =
+ cc->Inputs().Tag(kImageTag).Get();
+ image = mediapipe::formats::MatView(&output_frame);
+ }
+ cv::cvtColor(image, image, cv::COLOR_BGR2RGB);
+ const auto width = image.cols, height = image.rows,
+ num_channels = image.channels();
+ RET_CHECK_EQ(num_channels, 3) << "Only RGB images are supported for now";
+
+ cv::Mat img_float;
+ image.convertTo(img_float, CV_32F, 1.0 / 255);
+ auto img_tensor = torch::from_blob(img_float.data, {1, width, height, 3});
+ img_tensor = img_tensor.permute({0, 3, 1, 2}); // To NCWH
+ if (options_.per_channel_normalizations().size() > 0)
+ for (int i = 0; i < num_channels; ++i) {
+ const auto& subdiv = options_.per_channel_normalizations(i);
+ img_tensor[0][i] = img_tensor[0][i].sub_(subdiv.sub()).div_(subdiv.div());
+ }
+
+ auto output_tensors = absl::make_unique();
+ output_tensors->reserve(1);
+ output_tensors->emplace_back(img_tensor.cpu());
+ cc->Outputs()
+ .Tag(kTensorsTag)
+ .Add(output_tensors.release(), cc->InputTimestamp());
+ return ::mediapipe::OkStatus();
+}
+
+::mediapipe::Status PyTorchConverterCalculator::Close(CalculatorContext* cc) {
+ return ::mediapipe::OkStatus();
+}
+
+} // namespace mediapipe
diff --git a/mediapipe/calculators/pytorch/pytorch_converter_calculator.proto b/mediapipe/calculators/pytorch/pytorch_converter_calculator.proto
new file mode 100644
index 000000000..22b3ac9c3
--- /dev/null
+++ b/mediapipe/calculators/pytorch/pytorch_converter_calculator.proto
@@ -0,0 +1,34 @@
+// Copyright 2020 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto2";
+
+package mediapipe;
+
+import "mediapipe/framework/calculator.proto";
+
+message PyTorchConverterCalculatorOptions {
+ extend mediapipe.CalculatorOptions {
+ optional PyTorchConverterCalculatorOptions ext = 245877797;
+ }
+
+ message SubDiv {
+ optional float sub = 1 [default = 0];
+ optional float div = 2 [default = 1];
+ }
+ // Normalizations to apply per input image channel.
+ // There must be exactly as many items in this list as there are channels
+ // or none.
+ repeated SubDiv per_channel_normalizations = 1;
+}
diff --git a/mediapipe/calculators/pytorch/pytorch_inference_calculator.cc b/mediapipe/calculators/pytorch/pytorch_inference_calculator.cc
new file mode 100644
index 000000000..7573747a9
--- /dev/null
+++ b/mediapipe/calculators/pytorch/pytorch_inference_calculator.cc
@@ -0,0 +1,172 @@
+// Copyright 2020 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include
+#include
+
+#include "mediapipe/calculators/pytorch/pytorch_inference_calculator.pb.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/port/ret_check.h"
+#include "mediapipe/framework/port/status.h"
+#include "mediapipe/util/resource_util.h"
+#include "torch/script.h"
+#include "torch/torch.h"
+
+namespace mediapipe {
+
+namespace {
+constexpr char kTensorsTag[] = "TENSORS";
+
+using Inputs = std::vector;
+using Outputs = torch::Tensor;
+} // namespace
+
+// Calculator Header Section
+
+// Runs inference on the provided input TFLite tensors and TFLite model.
+//
+// Creates an interpreter with given model and calls invoke().
+// Optionally run inference on CPU/GPU.
+//
+// This calculator is designed to be used with the TfLiteConverterCalcualtor,
+// to get the appropriate inputs.
+//
+// When the input tensors are on CPU, gpu inference is optional and can be
+// specified in the calculator options.
+//
+// Input:
+// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32 or kTfLiteUInt8
+//
+// Output:
+// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32 or kTfLiteUInt8
+//
+// Example use:
+// node {
+// calculator: "PyTorchInferenceCalculator"
+// input_stream: "TENSORS:tensor_image"
+// output_stream: "TENSORS:tensors"
+// options: {
+// [mediapipe.PyTorchInferenceCalculatorOptions.ext] {
+// model_path: "modelname.tflite"
+// delegate { gpu {} }
+// }
+// }
+// }
+//
+// IMPORTANT Notes:
+// This calculator uses FixedSizeInputStreamHandler by default.
+//
+class PyTorchInferenceCalculator : public CalculatorBase {
+ public:
+ static ::mediapipe::Status GetContract(CalculatorContract* cc);
+
+ ::mediapipe::Status Open(CalculatorContext* cc) override;
+ ::mediapipe::Status Process(CalculatorContext* cc) override;
+ ::mediapipe::Status Close(CalculatorContext* cc) override;
+
+ private:
+ ::mediapipe::PyTorchInferenceCalculatorOptions options_;
+ torch::jit::script::Module module_;
+ torch::jit::IValue hidden_state_;
+};
+REGISTER_CALCULATOR(PyTorchInferenceCalculator);
+
+// Calculator Core Section
+
+::mediapipe::Status PyTorchInferenceCalculator::GetContract(
+ CalculatorContract* cc) {
+ RET_CHECK(cc->Inputs().HasTag(kTensorsTag));
+ RET_CHECK(cc->Outputs().HasTag(kTensorsTag));
+
+ if (cc->Inputs().HasTag(kTensorsTag))
+ cc->Inputs().Tag(kTensorsTag).Set();
+
+ if (cc->Outputs().HasTag(kTensorsTag))
+ cc->Outputs().Tag(kTensorsTag).Set();
+
+ // Assign this calculator's default InputStreamHandler.
+ cc->SetInputStreamHandler("FixedSizeInputStreamHandler");
+
+ return ::mediapipe::OkStatus();
+}
+
+::mediapipe::Status PyTorchInferenceCalculator::Open(CalculatorContext* cc) {
+ cc->SetOffset(TimestampDiff(0));
+
+ options_ = cc->Options<::mediapipe::PyTorchInferenceCalculatorOptions>();
+
+ std::string model_path = options_.model_path();
+ ASSIGN_OR_RETURN(model_path, mediapipe::PathToResourceAsFile(model_path));
+ try {
+ // https://github.com/pytorch/ios-demo-app/issues/8#issuecomment-612996683
+ auto qengines = at::globalContext().supportedQEngines();
+ if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) !=
+ qengines.end()) {
+ LOG(INFO) << "Using QEngine at::QEngine::QNNPACK";
+ at::globalContext().setQEngine(at::QEngine::QNNPACK);
+ }
+#if defined(MEDIAPIPE_IOS)
+ else {
+ RET_CHECK_FAIL() << "QEngine::QNNPACK is required for iOS";
+ }
+#endif
+
+ module_ = torch::jit::load(model_path);
+ module_.eval();
+ } catch (const std::exception& e) {
+ LOG(ERROR) << e.what();
+ return ::mediapipe::UnknownError(e.what());
+ }
+
+ if (options_.model_has_hidden_state())
+ hidden_state_ = torch::zeros({1, 1, 10}); // TODO: read options
+
+ return ::mediapipe::OkStatus();
+}
+
+::mediapipe::Status PyTorchInferenceCalculator::Process(CalculatorContext* cc) {
+ const auto inputs = cc->Inputs().Tag(kTensorsTag).Get();
+ RET_CHECK_GT(inputs.size(), 0);
+
+ // Disables autograd
+ torch::autograd::AutoGradMode guard(false);
+ // Disables autograd even more? https://github.com/pytorch/pytorch/pull/26477
+ at::AutoNonVariableTypeMode non_var_type_mode(true);
+
+ Outputs out_tensor;
+ if (options_.model_has_hidden_state()) {
+ RET_CHECK_EQ(inputs.size(), 1) << "Not sure how to forward() hidden state";
+ // auto tuple = torch::ivalue::Tuple::create({inp,hidden_state_});
+ const auto result = module_.forward({{inputs[0], hidden_state_}});
+ const auto out = result.toTuple()->elements();
+ out_tensor = out[0].toTensor();
+ hidden_state_ = out[1].toTensor();
+ } else {
+ const auto result = module_.forward(std::move(inputs));
+ out_tensor = result.toTensor();
+ }
+
+ auto out = absl::make_unique(out_tensor);
+ cc->Outputs().Tag(kTensorsTag).Add(out.release(), cc->InputTimestamp());
+
+ return ::mediapipe::OkStatus();
+}
+
+::mediapipe::Status PyTorchInferenceCalculator::Close(CalculatorContext* cc) {
+ return ::mediapipe::OkStatus();
+}
+
+} // namespace mediapipe
diff --git a/mediapipe/calculators/pytorch/pytorch_inference_calculator.proto b/mediapipe/calculators/pytorch/pytorch_inference_calculator.proto
new file mode 100644
index 000000000..8a93d2608
--- /dev/null
+++ b/mediapipe/calculators/pytorch/pytorch_inference_calculator.proto
@@ -0,0 +1,71 @@
+// Copyright 2020 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto2";
+
+package mediapipe;
+
+import "mediapipe/framework/calculator.proto";
+
+message PyTorchInferenceCalculatorOptions {
+ extend mediapipe.CalculatorOptions {
+ optional PyTorchInferenceCalculatorOptions ext = 233877213;
+ }
+
+ // message Delegate {
+ // // Default inference provided by tflite.
+ // message TfLite {}
+ // // Delegate to run GPU inference depending on the device.
+ // // (Can use OpenGl, OpenCl, Metal depending on the device.)
+ // message Gpu {}
+ // // Android only.
+ // message Nnapi {}
+
+ // oneof delegate {
+ // TfLite tflite = 1;
+ // Gpu gpu = 2;
+ // Nnapi nnapi = 3;
+ // }
+ // }
+
+ // Path to the PyTorch trace model (ex: /path/to/modelname.pt).
+ optional string model_path = 1;
+
+ // Whether model has hidden state.
+ optional bool model_has_hidden_state = 2;
+
+ // // Whether the TF Lite GPU or CPU backend should be used. Effective only
+ // when
+ // // input tensors are on CPU. For input tensors on GPU, GPU backend is
+ // always
+ // // used.
+ // // DEPRECATED: configure "delegate" instead.
+ // optional bool use_gpu = 2 [deprecated = true, default = false];
+
+ // // Android only. When true, an NNAPI delegate will be used for inference.
+ // // If NNAPI is not available, then the default CPU delegate will be used
+ // // automatically.
+ // // DEPRECATED: configure "delegate" instead.
+ // optional bool use_nnapi = 3 [deprecated = true, default = false];
+
+ // // The number of threads available to the interpreter. Effective only when
+ // // input tensors are on CPU and 'use_gpu' is false.
+ // optional int32 cpu_num_thread = 4 [default = -1];
+
+ // // TfLite delegate to run inference.
+ // // NOTE: calculator is free to choose delegate if not specified explicitly.
+ // // NOTE: use_gpu/use_nnapi are ignored if specified. (Delegate takes
+ // // precedence over use_* deprecated options.)
+ // optional Delegate delegate = 5;
+}
diff --git a/mediapipe/calculators/pytorch/pytorch_tensors_to_classification_calculator.cc b/mediapipe/calculators/pytorch/pytorch_tensors_to_classification_calculator.cc
new file mode 100644
index 000000000..b04b9a881
--- /dev/null
+++ b/mediapipe/calculators/pytorch/pytorch_tensors_to_classification_calculator.cc
@@ -0,0 +1,173 @@
+// Copyright 2020 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include
+
+#include "absl/strings/str_format.h"
+#include "absl/types/span.h"
+#include "mediapipe/calculators/pytorch/pytorch_tensors_to_classification_calculator.pb.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/formats/classification.pb.h"
+#include "mediapipe/framework/port/ret_check.h"
+#include "mediapipe/framework/port/status.h"
+#include "mediapipe/util/resource_util.h"
+#include "torch/torch.h"
+#if defined(MEDIAPIPE_MOBILE)
+#include "mediapipe/util/android/file/base/file.h"
+#include "mediapipe/util/android/file/base/helpers.h"
+#else
+#include "mediapipe/framework/port/file_helpers.h"
+#endif
+
+namespace mediapipe {
+
+namespace {
+constexpr char kTensorsTag[] = "TENSORS";
+constexpr char kClassificationListTag[] = "CLASSIFICATION_LIST";
+
+using Inputs = torch::Tensor;
+} // namespace
+
+// Convert result PyTorch tensors from classification models into MediaPipe
+// classifications.
+//
+// Input:
+// TENSORS - Vector of PyTorch of type kTfLiteFloat32 containing one
+// tensor, the size of which must be (1, * num_classes).
+// Output:
+// CLASSIFICATIONS - Result MediaPipe ClassificationList. The score and index
+// fields of each classification are set, while the label
+// field is only set if label_map_path is provided.
+//
+// Usage example:
+// node {
+// calculator: "PyTorchTensorsToClassificationCalculator"
+// input_stream: "TENSORS:tensors"
+// output_stream: "CLASSIFICATIONS:classifications"
+// options: {
+// [mediapipe.PyTorchTensorsToClassificationCalculatorOptions.ext] {
+// num_classes: 1024
+// min_score_threshold: 0.1
+// label_map_path: "labelmap.txt"
+// }
+// }
+// }
+class PyTorchTensorsToClassificationCalculator : public CalculatorBase {
+ public:
+ static ::mediapipe::Status GetContract(CalculatorContract* cc);
+
+ ::mediapipe::Status Open(CalculatorContext* cc) override;
+ ::mediapipe::Status Process(CalculatorContext* cc) override;
+ ::mediapipe::Status Close(CalculatorContext* cc) override;
+
+ private:
+ ::mediapipe::PyTorchTensorsToClassificationCalculatorOptions options_;
+ std::unordered_map label_map_;
+ bool label_map_loaded_ = false;
+};
+REGISTER_CALCULATOR(PyTorchTensorsToClassificationCalculator);
+
+::mediapipe::Status PyTorchTensorsToClassificationCalculator::GetContract(
+ CalculatorContract* cc) {
+ RET_CHECK(cc->Inputs().HasTag(kTensorsTag));
+ cc->Inputs().Tag(kTensorsTag).Set();
+
+ RET_CHECK(!cc->Outputs().GetTags().empty());
+ if (cc->Outputs().HasTag(kClassificationListTag)) {
+ cc->Outputs().Tag(kClassificationListTag).Set();
+ }
+
+ return ::mediapipe::OkStatus();
+}
+
+::mediapipe::Status PyTorchTensorsToClassificationCalculator::Open(
+ CalculatorContext* cc) {
+ cc->SetOffset(TimestampDiff(0));
+
+ options_ = cc->Options<
+ ::mediapipe::PyTorchTensorsToClassificationCalculatorOptions>();
+
+ if (options_.has_label_map_path()) {
+ std::string string_path;
+ ASSIGN_OR_RETURN(string_path,
+ PathToResourceAsFile(options_.label_map_path()));
+ std::string label_map_string;
+ MP_RETURN_IF_ERROR(file::GetContents(string_path, &label_map_string));
+
+ std::istringstream stream(label_map_string);
+ std::string line;
+ int i = 0;
+ while (std::getline(stream, line)) label_map_[i++] = line;
+ label_map_loaded_ = true;
+ }
+
+ if (options_.has_top_k()) RET_CHECK_GT(options_.top_k(), 0);
+
+ return ::mediapipe::OkStatus();
+}
+
+::mediapipe::Status PyTorchTensorsToClassificationCalculator::Process(
+ CalculatorContext* cc) {
+ const auto& input_tensors = cc->Inputs().Tag(kTensorsTag).Get();
+ RET_CHECK_EQ(input_tensors.dim(), 2);
+
+ std::tuple result =
+ input_tensors.sort(/*dim*/ -1, /*descending*/ true);
+ const torch::Tensor scores_tensor = std::get<0>(result)[0];
+ const torch::Tensor indices_tensor =
+ std::get<1>(result)[0].toType(torch::kInt32);
+
+ auto scores = scores_tensor.accessor();
+ auto indices = indices_tensor.accessor();
+
+ const auto indices_count = indices.size(0);
+ RET_CHECK_EQ(indices_count, scores.size(0));
+ if (label_map_loaded_)
+ RET_CHECK_EQ(indices_count, label_map_.size())
+ << "need: " << indices_count << ", got: " << label_map_.size();
+
+ // RET_CHECK_GE(indices_count, options_.top_k());
+ auto top_k = indices.size(0);
+ if (options_.has_top_k()) top_k = options_.top_k();
+
+ auto classification_list = absl::make_unique();
+ for (int i = 0; i < indices_count; ++i) {
+ if (classification_list->classification_size() == top_k) break;
+ const float score = scores[i];
+ const int index = indices[i];
+ if (options_.has_min_score_threshold() &&
+ score < options_.min_score_threshold())
+ continue;
+
+ Classification* classification = classification_list->add_classification();
+ classification->set_score(score);
+ classification->set_index(index);
+ if (label_map_loaded_) classification->set_label(label_map_[index]);
+ }
+
+ cc->Outputs()
+ .Tag(kClassificationListTag)
+ .Add(classification_list.release(), cc->InputTimestamp());
+
+ return ::mediapipe::OkStatus();
+}
+
+::mediapipe::Status PyTorchTensorsToClassificationCalculator::Close(
+ CalculatorContext* cc) {
+ return ::mediapipe::OkStatus();
+}
+
+} // namespace mediapipe
diff --git a/mediapipe/calculators/pytorch/pytorch_tensors_to_classification_calculator.proto b/mediapipe/calculators/pytorch/pytorch_tensors_to_classification_calculator.proto
new file mode 100644
index 000000000..efac275c6
--- /dev/null
+++ b/mediapipe/calculators/pytorch/pytorch_tensors_to_classification_calculator.proto
@@ -0,0 +1,33 @@
+// Copyright 2020 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto2";
+
+package mediapipe;
+
+import "mediapipe/framework/calculator.proto";
+
+message PyTorchTensorsToClassificationCalculatorOptions {
+ extend .mediapipe.CalculatorOptions {
+ optional PyTorchTensorsToClassificationCalculatorOptions ext = 266379463;
+ }
+
+ // Score threshold for perserving the class.
+ optional float min_score_threshold = 1;
+ // Number of highest scoring labels to output. If top_k is not positive then
+ // all labels are used.
+ optional int32 top_k = 2;
+ // Path to a label map file for getting the actual name of class ids.
+ optional string label_map_path = 3;
+}
diff --git a/mediapipe/examples/desktop/object_classification/BUILD b/mediapipe/examples/desktop/object_classification/BUILD
new file mode 100644
index 000000000..3571cfa8a
--- /dev/null
+++ b/mediapipe/examples/desktop/object_classification/BUILD
@@ -0,0 +1,27 @@
+# Copyright 2020 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("@rules_cc//cc:defs.bzl", "cc_binary")
+
+licenses(["notice"])
+
+package(default_visibility = ["//mediapipe/examples:__subpackages__"])
+
+cc_binary(
+ name = "object_classification_pytorch_cpu",
+ deps = [
+ "//mediapipe/examples/desktop:demo_run_graph_main",
+ "//mediapipe/graphs/object_classification:desktop_pytorch_calculators",
+ ],
+)
diff --git a/mediapipe/graphs/object_classification/BUILD b/mediapipe/graphs/object_classification/BUILD
new file mode 100644
index 000000000..dc1552659
--- /dev/null
+++ b/mediapipe/graphs/object_classification/BUILD
@@ -0,0 +1,32 @@
+# Copyright 2020 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("@rules_cc//cc:defs.bzl", "cc_library")
+
+licenses(["notice"])
+
+package(default_visibility = ["//visibility:public"])
+
+cc_library(
+ name = "desktop_pytorch_calculators",
+ deps = [
+ "//mediapipe/calculators/core:flow_limiter_calculator",
+ "//mediapipe/calculators/image:image_transformation_calculator",
+ "//mediapipe/calculators/pytorch:pytorch_converter_calculator",
+ "//mediapipe/calculators/pytorch:pytorch_inference_calculator",
+ "//mediapipe/calculators/pytorch:pytorch_tensors_to_classification_calculator",
+ "//mediapipe/calculators/util:annotation_overlay_calculator",
+ "//mediapipe/calculators/util:labels_to_render_data_calculator",
+ ],
+)
diff --git a/mediapipe/graphs/object_classification/object_classification_desktop_live.pbtxt b/mediapipe/graphs/object_classification/object_classification_desktop_live.pbtxt
new file mode 100644
index 000000000..f4b1524e7
--- /dev/null
+++ b/mediapipe/graphs/object_classification/object_classification_desktop_live.pbtxt
@@ -0,0 +1,120 @@
+# MediaPipe graph that performs object classification with PyTorch on CPU.
+# Used in the examples in
+# mediapipe/examples/desktop/object_classification:object_classification_pytorch_cpu.
+
+# Images on CPU coming into and out of the graph.
+input_stream: "input_video"
+output_stream: "output_video"
+
+# Throttles the images flowing downstream for flow control. It passes through
+# the very first incoming image unaltered, and waits for
+# PyTorchTensorsToClassificationCalculator downstream in the graph to finish
+# generating the corresponding classifications before it passes through another
+# image. All images that come in while waiting are dropped, limiting the number
+# of in-flight images between this calculator and
+# PyTorchTensorsToClassificationCalculator to 1. This prevents the nodes in between
+# from queuing up incoming images and data excessively, which leads to increased
+# latency and memory usage, unwanted in real-time mobile applications. It also
+# eliminates unnecessarily computation, e.g., a transformed image produced by
+# ImageTransformationCalculator may get dropped downstream if the subsequent
+# PyTorchConverterCalculator or PyTorchInferenceCalculator is still busy
+# processing previous inputs.
+node {
+ calculator: "FlowLimiterCalculator"
+ input_stream: "input_video"
+ input_stream: "FINISHED:output_classifications"
+ input_stream_info: {
+ tag_index: "FINISHED"
+ back_edge: true
+ }
+ output_stream: "throttled_input_video"
+}
+
+# Transforms the input image on CPU to a 224x224 image. To scale the input
+# image, the scale_mode option is set to FIT to preserve the aspect ratio,
+# resulting in potential letterboxing in the transformed image.
+node: {
+ calculator: "ImageTransformationCalculator"
+ input_stream: "IMAGE:throttled_input_video"
+ output_stream: "IMAGE:transformed_input_video"
+ output_stream: "LETTERBOX_PADDING:letterbox_padding"
+ node_options: {
+ [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] {
+ output_width: 224
+ output_height: 224
+ scale_mode: FIT
+ }
+ }
+}
+
+# Converts the transformed input image on CPU into an image tensor stored as a
+# PyTorch tensor. Pixel values are normalized using mean = [0.485, 0.456, 0.406]
+# and std = [0.229, 0.224, 0.225].
+node {
+ calculator: "PyTorchConverterCalculator"
+ input_stream: "IMAGE:transformed_input_video"
+ output_stream: "TENSORS:image_tensor"
+ node_options: {
+ [type.googleapis.com/mediapipe.PyTorchConverterCalculatorOptions] {
+ per_channel_normalizations: {sub:0.485 div:0.229}
+ per_channel_normalizations: {sub:0.456 div:0.224}
+ per_channel_normalizations: {sub:0.406 div:0.225}
+ }
+ }
+}
+
+# Runs a PyTorch model on CPU that takes an image tensor and outputs a
+# vector of tensors representing, for instance, classification scores.
+node {
+ calculator: "PyTorchInferenceCalculator"
+ input_stream: "TENSORS:image_tensor"
+ output_stream: "TENSORS:classification_tensors"
+ node_options: {
+ [type.googleapis.com/mediapipe.PyTorchInferenceCalculatorOptions] {
+ model_path: "mediapipe/models/mobilenetv2.pt"
+ }
+ }
+}
+
+# Decodes the classifications tensors generated by the PyTorch model, based on
+# the specification in the options, into a vector of classifications.
+# Maps classification label IDs to the corresponding label text. The label map is
+# provided in the label_map_path option.
+node {
+ calculator: "PyTorchTensorsToClassificationCalculator"
+ input_stream: "TENSORS:classification_tensors"
+ output_stream: "CLASSIFICATION_LIST:output_classifications"
+ node_options: {
+ [type.googleapis.com/mediapipe.PyTorchTensorsToClassificationCalculatorOptions] {
+ top_k: 3
+ min_score_threshold: 0.1
+ label_map_path: "mediapipe/models/mobilenetv2.labelmap"
+ }
+ }
+}
+
+# Converts the classifications label to drawing primitives for annotation overlay.
+node {
+ calculator: "LabelsToRenderDataCalculator"
+ input_stream: "CLASSIFICATIONS:output_classifications"
+ output_stream: "RENDER_DATA:render_data"
+ node_options: {
+ [type.googleapis.com/mediapipe.LabelsToRenderDataCalculatorOptions] {
+ color { r: 255 g: 0 b: 0 }
+ color { r: 0 g: 255 b: 0 }
+ color { r: 0 g: 0 b: 255 }
+ thickness: 2.0
+ font_height_px: 20
+ font_face: 1
+ location: TOP_LEFT
+ }
+ }
+}
+
+# Draws annotations and overlays them on top of the input images.
+node {
+ calculator: "AnnotationOverlayCalculator"
+ input_stream: "IMAGE:throttled_input_video"
+ input_stream: "render_data"
+ output_stream: "IMAGE:output_video"
+}
diff --git a/mediapipe/models/.gitignore b/mediapipe/models/.gitignore
new file mode 100644
index 000000000..5f849ac87
--- /dev/null
+++ b/mediapipe/models/.gitignore
@@ -0,0 +1 @@
+/mobilenetv2.pt
diff --git a/mediapipe/models/mobilenetv2.labelmap b/mediapipe/models/mobilenetv2.labelmap
new file mode 100644
index 000000000..a509c0074
--- /dev/null
+++ b/mediapipe/models/mobilenetv2.labelmap
@@ -0,0 +1,1000 @@
+tench, Tinca tinca
+goldfish, Carassius auratus
+great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
+tiger shark, Galeocerdo cuvieri
+hammerhead, hammerhead shark
+electric ray, crampfish, numbfish, torpedo
+stingray
+cock
+hen
+ostrich, Struthio camelus
+brambling, Fringilla montifringilla
+goldfinch, Carduelis carduelis
+house finch, linnet, Carpodacus mexicanus
+junco, snowbird
+indigo bunting, indigo finch, indigo bird, Passerina cyanea
+robin, American robin, Turdus migratorius
+bulbul
+jay
+magpie
+chickadee
+water ouzel, dipper
+kite
+bald eagle, American eagle, Haliaeetus leucocephalus
+vulture
+great grey owl, great gray owl, Strix nebulosa
+European fire salamander, Salamandra salamandra
+common newt, Triturus vulgaris
+eft
+spotted salamander, Ambystoma maculatum
+axolotl, mud puppy, Ambystoma mexicanum
+bullfrog, Rana catesbeiana
+tree frog, tree-frog
+tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui
+loggerhead, loggerhead turtle, Caretta caretta
+leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea
+mud turtle
+terrapin
+box turtle, box tortoise
+banded gecko
+common iguana, iguana, Iguana iguana
+American chameleon, anole, Anolis carolinensis
+whiptail, whiptail lizard
+agama
+frilled lizard, Chlamydosaurus kingi
+alligator lizard
+Gila monster, Heloderma suspectum
+green lizard, Lacerta viridis
+African chameleon, Chamaeleo chamaeleon
+Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis
+African crocodile, Nile crocodile, Crocodylus niloticus
+American alligator, Alligator mississipiensis
+triceratops
+thunder snake, worm snake, Carphophis amoenus
+ringneck snake, ring-necked snake, ring snake
+hognose snake, puff adder, sand viper
+green snake, grass snake
+king snake, kingsnake
+garter snake, grass snake
+water snake
+vine snake
+night snake, Hypsiglena torquata
+boa constrictor, Constrictor constrictor
+rock python, rock snake, Python sebae
+Indian cobra, Naja naja
+green mamba
+sea snake
+horned viper, cerastes, sand viper, horned asp, Cerastes cornutus
+diamondback, diamondback rattlesnake, Crotalus adamanteus
+sidewinder, horned rattlesnake, Crotalus cerastes
+trilobite
+harvestman, daddy longlegs, Phalangium opilio
+scorpion
+black and gold garden spider, Argiope aurantia
+barn spider, Araneus cavaticus
+garden spider, Aranea diademata
+black widow, Latrodectus mactans
+tarantula
+wolf spider, hunting spider
+tick
+centipede
+black grouse
+ptarmigan
+ruffed grouse, partridge, Bonasa umbellus
+prairie chicken, prairie grouse, prairie fowl
+peacock
+quail
+partridge
+African grey, African gray, Psittacus erithacus
+macaw
+sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita
+lorikeet
+coucal
+bee eater
+hornbill
+hummingbird
+jacamar
+toucan
+drake
+red-breasted merganser, Mergus serrator
+goose
+black swan, Cygnus atratus
+tusker
+echidna, spiny anteater, anteater
+platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus
+wallaby, brush kangaroo
+koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus
+wombat
+jellyfish
+sea anemone, anemone
+brain coral
+flatworm, platyhelminth
+nematode, nematode worm, roundworm
+conch
+snail
+slug
+sea slug, nudibranch
+chiton, coat-of-mail shell, sea cradle, polyplacophore
+chambered nautilus, pearly nautilus, nautilus
+Dungeness crab, Cancer magister
+rock crab, Cancer irroratus
+fiddler crab
+king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica
+American lobster, Northern lobster, Maine lobster, Homarus americanus
+spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish
+crayfish, crawfish, crawdad, crawdaddy
+hermit crab
+isopod
+white stork, Ciconia ciconia
+black stork, Ciconia nigra
+spoonbill
+flamingo
+little blue heron, Egretta caerulea
+American egret, great white heron, Egretta albus
+bittern
+crane
+limpkin, Aramus pictus
+European gallinule, Porphyrio porphyrio
+American coot, marsh hen, mud hen, water hen, Fulica americana
+bustard
+ruddy turnstone, Arenaria interpres
+red-backed sandpiper, dunlin, Erolia alpina
+redshank, Tringa totanus
+dowitcher
+oystercatcher, oyster catcher
+pelican
+king penguin, Aptenodytes patagonica
+albatross, mollymawk
+grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus
+killer whale, killer, orca, grampus, sea wolf, Orcinus orca
+dugong, Dugong dugon
+sea lion
+Chihuahua
+Japanese spaniel
+Maltese dog, Maltese terrier, Maltese
+Pekinese, Pekingese, Peke
+Shih-Tzu
+Blenheim spaniel
+papillon
+toy terrier
+Rhodesian ridgeback
+Afghan hound, Afghan
+basset, basset hound
+beagle
+bloodhound, sleuthhound
+bluetick
+black-and-tan coonhound
+Walker hound, Walker foxhound
+English foxhound
+redbone
+borzoi, Russian wolfhound
+Irish wolfhound
+Italian greyhound
+whippet
+Ibizan hound, Ibizan Podenco
+Norwegian elkhound, elkhound
+otterhound, otter hound
+Saluki, gazelle hound
+Scottish deerhound, deerhound
+Weimaraner
+Staffordshire bullterrier, Staffordshire bull terrier
+American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier
+Bedlington terrier
+Border terrier
+Kerry blue terrier
+Irish terrier
+Norfolk terrier
+Norwich terrier
+Yorkshire terrier
+wire-haired fox terrier
+Lakeland terrier
+Sealyham terrier, Sealyham
+Airedale, Airedale terrier
+cairn, cairn terrier
+Australian terrier
+Dandie Dinmont, Dandie Dinmont terrier
+Boston bull, Boston terrier
+miniature schnauzer
+giant schnauzer
+standard schnauzer
+Scotch terrier, Scottish terrier, Scottie
+Tibetan terrier, chrysanthemum dog
+silky terrier, Sydney silky
+soft-coated wheaten terrier
+West Highland white terrier
+Lhasa, Lhasa apso
+flat-coated retriever
+curly-coated retriever
+golden retriever
+Labrador retriever
+Chesapeake Bay retriever
+German short-haired pointer
+vizsla, Hungarian pointer
+English setter
+Irish setter, red setter
+Gordon setter
+Brittany spaniel
+clumber, clumber spaniel
+English springer, English springer spaniel
+Welsh springer spaniel
+cocker spaniel, English cocker spaniel, cocker
+Sussex spaniel
+Irish water spaniel
+kuvasz
+schipperke
+groenendael
+malinois
+briard
+kelpie
+komondor
+Old English sheepdog, bobtail
+Shetland sheepdog, Shetland sheep dog, Shetland
+collie
+Border collie
+Bouvier des Flandres, Bouviers des Flandres
+Rottweiler
+German shepherd, German shepherd dog, German police dog, alsatian
+Doberman, Doberman pinscher
+miniature pinscher
+Greater Swiss Mountain dog
+Bernese mountain dog
+Appenzeller
+EntleBucher
+boxer
+bull mastiff
+Tibetan mastiff
+French bulldog
+Great Dane
+Saint Bernard, St Bernard
+Eskimo dog, husky
+malamute, malemute, Alaskan malamute
+Siberian husky
+dalmatian, coach dog, carriage dog
+affenpinscher, monkey pinscher, monkey dog
+basenji
+pug, pug-dog
+Leonberg
+Newfoundland, Newfoundland dog
+Great Pyrenees
+Samoyed, Samoyede
+Pomeranian
+chow, chow chow
+keeshond
+Brabancon griffon
+Pembroke, Pembroke Welsh corgi
+Cardigan, Cardigan Welsh corgi
+toy poodle
+miniature poodle
+standard poodle
+Mexican hairless
+timber wolf, grey wolf, gray wolf, Canis lupus
+white wolf, Arctic wolf, Canis lupus tundrarum
+red wolf, maned wolf, Canis rufus, Canis niger
+coyote, prairie wolf, brush wolf, Canis latrans
+dingo, warrigal, warragal, Canis dingo
+dhole, Cuon alpinus
+African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus
+hyena, hyaena
+red fox, Vulpes vulpes
+kit fox, Vulpes macrotis
+Arctic fox, white fox, Alopex lagopus
+grey fox, gray fox, Urocyon cinereoargenteus
+tabby, tabby cat
+tiger cat
+Persian cat
+Siamese cat, Siamese
+Egyptian cat
+cougar, puma, catamount, mountain lion, painter, panther, Felis concolor
+lynx, catamount
+leopard, Panthera pardus
+snow leopard, ounce, Panthera uncia
+jaguar, panther, Panthera onca, Felis onca
+lion, king of beasts, Panthera leo
+tiger, Panthera tigris
+cheetah, chetah, Acinonyx jubatus
+brown bear, bruin, Ursus arctos
+American black bear, black bear, Ursus americanus, Euarctos americanus
+ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus
+sloth bear, Melursus ursinus, Ursus ursinus
+mongoose
+meerkat, mierkat
+tiger beetle
+ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle
+ground beetle, carabid beetle
+long-horned beetle, longicorn, longicorn beetle
+leaf beetle, chrysomelid
+dung beetle
+rhinoceros beetle
+weevil
+fly
+bee
+ant, emmet, pismire
+grasshopper, hopper
+cricket
+walking stick, walkingstick, stick insect
+cockroach, roach
+mantis, mantid
+cicada, cicala
+leafhopper
+lacewing, lacewing fly
+dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk
+damselfly
+admiral
+ringlet, ringlet butterfly
+monarch, monarch butterfly, milkweed butterfly, Danaus plexippus
+cabbage butterfly
+sulphur butterfly, sulfur butterfly
+lycaenid, lycaenid butterfly
+starfish, sea star
+sea urchin
+sea cucumber, holothurian
+wood rabbit, cottontail, cottontail rabbit
+hare
+Angora, Angora rabbit
+hamster
+porcupine, hedgehog
+fox squirrel, eastern fox squirrel, Sciurus niger
+marmot
+beaver
+guinea pig, Cavia cobaya
+sorrel
+zebra
+hog, pig, grunter, squealer, Sus scrofa
+wild boar, boar, Sus scrofa
+warthog
+hippopotamus, hippo, river horse, Hippopotamus amphibius
+ox
+water buffalo, water ox, Asiatic buffalo, Bubalus bubalis
+bison
+ram, tup
+bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis
+ibex, Capra ibex
+hartebeest
+impala, Aepyceros melampus
+gazelle
+Arabian camel, dromedary, Camelus dromedarius
+llama
+weasel
+mink
+polecat, fitch, foulmart, foumart, Mustela putorius
+black-footed ferret, ferret, Mustela nigripes
+otter
+skunk, polecat, wood pussy
+badger
+armadillo
+three-toed sloth, ai, Bradypus tridactylus
+orangutan, orang, orangutang, Pongo pygmaeus
+gorilla, Gorilla gorilla
+chimpanzee, chimp, Pan troglodytes
+gibbon, Hylobates lar
+siamang, Hylobates syndactylus, Symphalangus syndactylus
+guenon, guenon monkey
+patas, hussar monkey, Erythrocebus patas
+baboon
+macaque
+langur
+colobus, colobus monkey
+proboscis monkey, Nasalis larvatus
+marmoset
+capuchin, ringtail, Cebus capucinus
+howler monkey, howler
+titi, titi monkey
+spider monkey, Ateles geoffroyi
+squirrel monkey, Saimiri sciureus
+Madagascar cat, ring-tailed lemur, Lemur catta
+indri, indris, Indri indri, Indri brevicaudatus
+Indian elephant, Elephas maximus
+African elephant, Loxodonta africana
+lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens
+giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca
+barracouta, snoek
+eel
+coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch
+rock beauty, Holocanthus tricolor
+anemone fish
+sturgeon
+gar, garfish, garpike, billfish, Lepisosteus osseus
+lionfish
+puffer, pufferfish, blowfish, globefish
+abacus
+abaya
+academic gown, academic robe, judge's robe
+accordion, piano accordion, squeeze box
+acoustic guitar
+aircraft carrier, carrier, flattop, attack aircraft carrier
+airliner
+airship, dirigible
+altar
+ambulance
+amphibian, amphibious vehicle
+analog clock
+apiary, bee house
+apron
+ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin
+assault rifle, assault gun
+backpack, back pack, knapsack, packsack, rucksack, haversack
+bakery, bakeshop, bakehouse
+balance beam, beam
+balloon
+ballpoint, ballpoint pen, ballpen, Biro
+Band Aid
+banjo
+bannister, banister, balustrade, balusters, handrail
+barbell
+barber chair
+barbershop
+barn
+barometer
+barrel, cask
+barrow, garden cart, lawn cart, wheelbarrow
+baseball
+basketball
+bassinet
+bassoon
+bathing cap, swimming cap
+bath towel
+bathtub, bathing tub, bath, tub
+beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon
+beacon, lighthouse, beacon light, pharos
+beaker
+bearskin, busby, shako
+beer bottle
+beer glass
+bell cote, bell cot
+bib
+bicycle-built-for-two, tandem bicycle, tandem
+bikini, two-piece
+binder, ring-binder
+binoculars, field glasses, opera glasses
+birdhouse
+boathouse
+bobsled, bobsleigh, bob
+bolo tie, bolo, bola tie, bola
+bonnet, poke bonnet
+bookcase
+bookshop, bookstore, bookstall
+bottlecap
+bow
+bow tie, bow-tie, bowtie
+brass, memorial tablet, plaque
+brassiere, bra, bandeau
+breakwater, groin, groyne, mole, bulwark, seawall, jetty
+breastplate, aegis, egis
+broom
+bucket, pail
+buckle
+bulletproof vest
+bullet train, bullet
+butcher shop, meat market
+cab, hack, taxi, taxicab
+caldron, cauldron
+candle, taper, wax light
+cannon
+canoe
+can opener, tin opener
+cardigan
+car mirror
+carousel, carrousel, merry-go-round, roundabout, whirligig
+carpenter's kit, tool kit
+carton
+car wheel
+cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM
+cassette
+cassette player
+castle
+catamaran
+CD player
+cello, violoncello
+cellular telephone, cellular phone, cellphone, cell, mobile phone
+chain
+chainlink fence
+chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour
+chain saw, chainsaw
+chest
+chiffonier, commode
+chime, bell, gong
+china cabinet, china closet
+Christmas stocking
+church, church building
+cinema, movie theater, movie theatre, movie house, picture palace
+cleaver, meat cleaver, chopper
+cliff dwelling
+cloak
+clog, geta, patten, sabot
+cocktail shaker
+coffee mug
+coffeepot
+coil, spiral, volute, whorl, helix
+combination lock
+computer keyboard, keypad
+confectionery, confectionary, candy store
+container ship, containership, container vessel
+convertible
+corkscrew, bottle screw
+cornet, horn, trumpet, trump
+cowboy boot
+cowboy hat, ten-gallon hat
+cradle
+crane
+crash helmet
+crate
+crib, cot
+Crock Pot
+croquet ball
+crutch
+cuirass
+dam, dike, dyke
+desk
+desktop computer
+dial telephone, dial phone
+diaper, nappy, napkin
+digital clock
+digital watch
+dining table, board
+dishrag, dishcloth
+dishwasher, dish washer, dishwashing machine
+disk brake, disc brake
+dock, dockage, docking facility
+dogsled, dog sled, dog sleigh
+dome
+doormat, welcome mat
+drilling platform, offshore rig
+drum, membranophone, tympan
+drumstick
+dumbbell
+Dutch oven
+electric fan, blower
+electric guitar
+electric locomotive
+entertainment center
+envelope
+espresso maker
+face powder
+feather boa, boa
+file, file cabinet, filing cabinet
+fireboat
+fire engine, fire truck
+fire screen, fireguard
+flagpole, flagstaff
+flute, transverse flute
+folding chair
+football helmet
+forklift
+fountain
+fountain pen
+four-poster
+freight car
+French horn, horn
+frying pan, frypan, skillet
+fur coat
+garbage truck, dustcart
+gasmask, respirator, gas helmet
+gas pump, gasoline pump, petrol pump, island dispenser
+goblet
+go-kart
+golf ball
+golfcart, golf cart
+gondola
+gong, tam-tam
+gown
+grand piano, grand
+greenhouse, nursery, glasshouse
+grille, radiator grille
+grocery store, grocery, food market, market
+guillotine
+hair slide
+hair spray
+half track
+hammer
+hamper
+hand blower, blow dryer, blow drier, hair dryer, hair drier
+hand-held computer, hand-held microcomputer
+handkerchief, hankie, hanky, hankey
+hard disc, hard disk, fixed disk
+harmonica, mouth organ, harp, mouth harp
+harp
+harvester, reaper
+hatchet
+holster
+home theater, home theatre
+honeycomb
+hook, claw
+hoopskirt, crinoline
+horizontal bar, high bar
+horse cart, horse-cart
+hourglass
+iPod
+iron, smoothing iron
+jack-o'-lantern
+jean, blue jean, denim
+jeep, landrover
+jersey, T-shirt, tee shirt
+jigsaw puzzle
+jinrikisha, ricksha, rickshaw
+joystick
+kimono
+knee pad
+knot
+lab coat, laboratory coat
+ladle
+lampshade, lamp shade
+laptop, laptop computer
+lawn mower, mower
+lens cap, lens cover
+letter opener, paper knife, paperknife
+library
+lifeboat
+lighter, light, igniter, ignitor
+limousine, limo
+liner, ocean liner
+lipstick, lip rouge
+Loafer
+lotion
+loudspeaker, speaker, speaker unit, loudspeaker system, speaker system
+loupe, jeweler's loupe
+lumbermill, sawmill
+magnetic compass
+mailbag, postbag
+mailbox, letter box
+maillot
+maillot, tank suit
+manhole cover
+maraca
+marimba, xylophone
+mask
+matchstick
+maypole
+maze, labyrinth
+measuring cup
+medicine chest, medicine cabinet
+megalith, megalithic structure
+microphone, mike
+microwave, microwave oven
+military uniform
+milk can
+minibus
+miniskirt, mini
+minivan
+missile
+mitten
+mixing bowl
+mobile home, manufactured home
+Model T
+modem
+monastery
+monitor
+moped
+mortar
+mortarboard
+mosque
+mosquito net
+motor scooter, scooter
+mountain bike, all-terrain bike, off-roader
+mountain tent
+mouse, computer mouse
+mousetrap
+moving van
+muzzle
+nail
+neck brace
+necklace
+nipple
+notebook, notebook computer
+obelisk
+oboe, hautboy, hautbois
+ocarina, sweet potato
+odometer, hodometer, mileometer, milometer
+oil filter
+organ, pipe organ
+oscilloscope, scope, cathode-ray oscilloscope, CRO
+overskirt
+oxcart
+oxygen mask
+packet
+paddle, boat paddle
+paddlewheel, paddle wheel
+padlock
+paintbrush
+pajama, pyjama, pj's, jammies
+palace
+panpipe, pandean pipe, syrinx
+paper towel
+parachute, chute
+parallel bars, bars
+park bench
+parking meter
+passenger car, coach, carriage
+patio, terrace
+pay-phone, pay-station
+pedestal, plinth, footstall
+pencil box, pencil case
+pencil sharpener
+perfume, essence
+Petri dish
+photocopier
+pick, plectrum, plectron
+pickelhaube
+picket fence, paling
+pickup, pickup truck
+pier
+piggy bank, penny bank
+pill bottle
+pillow
+ping-pong ball
+pinwheel
+pirate, pirate ship
+pitcher, ewer
+plane, carpenter's plane, woodworking plane
+planetarium
+plastic bag
+plate rack
+plow, plough
+plunger, plumber's helper
+Polaroid camera, Polaroid Land camera
+pole
+police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria
+poncho
+pool table, billiard table, snooker table
+pop bottle, soda bottle
+pot, flowerpot
+potter's wheel
+power drill
+prayer rug, prayer mat
+printer
+prison, prison house
+projectile, missile
+projector
+puck, hockey puck
+punching bag, punch bag, punching ball, punchball
+purse
+quill, quill pen
+quilt, comforter, comfort, puff
+racer, race car, racing car
+racket, racquet
+radiator
+radio, wireless
+radio telescope, radio reflector
+rain barrel
+recreational vehicle, RV, R.V.
+reel
+reflex camera
+refrigerator, icebox
+remote control, remote
+restaurant, eating house, eating place, eatery
+revolver, six-gun, six-shooter
+rifle
+rocking chair, rocker
+rotisserie
+rubber eraser, rubber, pencil eraser
+rugby ball
+rule, ruler
+running shoe
+safe
+safety pin
+saltshaker, salt shaker
+sandal
+sarong
+sax, saxophone
+scabbard
+scale, weighing machine
+school bus
+schooner
+scoreboard
+screen, CRT screen
+screw
+screwdriver
+seat belt, seatbelt
+sewing machine
+shield, buckler
+shoe shop, shoe-shop, shoe store
+shoji
+shopping basket
+shopping cart
+shovel
+shower cap
+shower curtain
+ski
+ski mask
+sleeping bag
+slide rule, slipstick
+sliding door
+slot, one-armed bandit
+snorkel
+snowmobile
+snowplow, snowplough
+soap dispenser
+soccer ball
+sock
+solar dish, solar collector, solar furnace
+sombrero
+soup bowl
+space bar
+space heater
+space shuttle
+spatula
+speedboat
+spider web, spider's web
+spindle
+sports car, sport car
+spotlight, spot
+stage
+steam locomotive
+steel arch bridge
+steel drum
+stethoscope
+stole
+stone wall
+stopwatch, stop watch
+stove
+strainer
+streetcar, tram, tramcar, trolley, trolley car
+stretcher
+studio couch, day bed
+stupa, tope
+submarine, pigboat, sub, U-boat
+suit, suit of clothes
+sundial
+sunglass
+sunglasses, dark glasses, shades
+sunscreen, sunblock, sun blocker
+suspension bridge
+swab, swob, mop
+sweatshirt
+swimming trunks, bathing trunks
+swing
+switch, electric switch, electrical switch
+syringe
+table lamp
+tank, army tank, armored combat vehicle, armoured combat vehicle
+tape player
+teapot
+teddy, teddy bear
+television, television system
+tennis ball
+thatch, thatched roof
+theater curtain, theatre curtain
+thimble
+thresher, thrasher, threshing machine
+throne
+tile roof
+toaster
+tobacco shop, tobacconist shop, tobacconist
+toilet seat
+torch
+totem pole
+tow truck, tow car, wrecker
+toyshop
+tractor
+trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi
+tray
+trench coat
+tricycle, trike, velocipede
+trimaran
+tripod
+triumphal arch
+trolleybus, trolley coach, trackless trolley
+trombone
+tub, vat
+turnstile
+typewriter keyboard
+umbrella
+unicycle, monocycle
+upright, upright piano
+vacuum, vacuum cleaner
+vase
+vault
+velvet
+vending machine
+vestment
+viaduct
+violin, fiddle
+volleyball
+waffle iron
+wall clock
+wallet, billfold, notecase, pocketbook
+wardrobe, closet, press
+warplane, military plane
+washbasin, handbasin, washbowl, lavabo, wash-hand basin
+washer, automatic washer, washing machine
+water bottle
+water jug
+water tower
+whiskey jug
+whistle
+wig
+window screen
+window shade
+Windsor tie
+wine bottle
+wing
+wok
+wooden spoon
+wool, woolen, woollen
+worm fence, snake fence, snake-rail fence, Virginia fence
+wreck
+yawl
+yurt
+web site, website, internet site, site
+comic book
+crossword puzzle, crossword
+street sign
+traffic light, traffic signal, stoplight
+book jacket, dust cover, dust jacket, dust wrapper
+menu
+plate
+guacamole
+consomme
+hot pot, hotpot
+trifle
+ice cream, icecream
+ice lolly, lolly, lollipop, popsicle
+French loaf
+bagel, beigel
+pretzel
+cheeseburger
+hotdog, hot dog, red hot
+mashed potato
+head cabbage
+broccoli
+cauliflower
+zucchini, courgette
+spaghetti squash
+acorn squash
+butternut squash
+cucumber, cuke
+artichoke, globe artichoke
+bell pepper
+cardoon
+mushroom
+Granny Smith
+strawberry
+orange
+lemon
+fig
+pineapple, ananas
+banana
+jackfruit, jak, jack
+custard apple
+pomegranate
+hay
+carbonara
+chocolate sauce, chocolate syrup
+dough
+meat loaf, meatloaf
+pizza, pizza pie
+potpie
+burrito
+red wine
+espresso
+cup
+eggnog
+alp
+bubble
+cliff, drop, drop-off
+coral reef
+geyser
+lakeside, lakeshore
+promontory, headland, head, foreland
+sandbar, sand bar
+seashore, coast, seacoast, sea-coast
+valley, vale
+volcano
+ballplayer, baseball player
+groom, bridegroom
+scuba diver
+rapeseed
+daisy
+yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum
+corn
+acorn
+hip, rose hip, rosehip
+buckeye, horse chestnut, conker
+coral fungus
+agaric
+gyromitra
+stinkhorn, carrion fungus
+earthstar
+hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa
+bolete
+ear, spike, capitulum
+toilet tissue, toilet paper, bathroom tissue
diff --git a/mediapipe/models/trace_mobilenetv2.py b/mediapipe/models/trace_mobilenetv2.py
new file mode 100644
index 000000000..ff993287d
--- /dev/null
+++ b/mediapipe/models/trace_mobilenetv2.py
@@ -0,0 +1,14 @@
+# `torch` package known to work:
+# python -m pip install torch==1.4.0+cpu torchvision==0.5.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
+import torch
+
+# More about this model at https://pytorch.org/hub/pytorch_vision_mobilenet_v2
+
+if __name__ == '__main__':
+ model = torch.hub.load('pytorch/vision:v0.5.0',
+ 'mobilenet_v2',
+ pretrained=True)
+ model.eval()
+ input_tensor = torch.rand(1, 3, 224, 224)
+ script_model = torch.jit.trace(model, input_tensor)
+ script_model.save("mediapipe/models/mobilenetv2.pt")
diff --git a/third_party/BUILD b/third_party/BUILD
index 4d2676751..89b28818a 100644
--- a/third_party/BUILD
+++ b/third_party/BUILD
@@ -209,6 +209,21 @@ cc_library(
}),
)
+cc_library(
+ name = "libtorch",
+ visibility = ["//visibility:public"],
+ deps = select({
+ "//mediapipe:android_x86": [],
+ "//mediapipe:android_x86_64": [],
+ "//mediapipe:android_armeabi": [],
+ "//mediapipe:android_arm": [],
+ "//mediapipe:android_arm64": [],
+ "//mediapipe:ios": ["@ios_libtorch//:libtorch"],
+ "//mediapipe:macos": ["@macos_libtorch_cpu//:libtorch_cpu"],
+ "//conditions:default": ["@linux_libtorch_cpu//:libtorch_cpu"],
+ }),
+)
+
android_library(
name = "androidx_annotation",
exports = [
diff --git a/third_party/libtorch_linux.BUILD b/third_party/libtorch_linux.BUILD
new file mode 100644
index 000000000..a78651592
--- /dev/null
+++ b/third_party/libtorch_linux.BUILD
@@ -0,0 +1,32 @@
+# Copyright 2020 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("@rules_cc//cc:defs.bzl", "cc_library")
+
+cc_library(
+ name = "libtorch_cpu",
+ srcs = glob([
+ "lib/libc10.so",
+ "lib/libgomp-*.so.1",
+ "lib/libtorch.so",
+ ]),
+ hdrs = glob(["include/**/*.h"]),
+ includes = [
+ "include",
+ "include/TH",
+ "include/THC",
+ "include/torch/csrc/api/include",
+ ],
+ visibility = ["//visibility:public"],
+)