Project import generated by Copybara.

GitOrigin-RevId: 283c1a295de0a53e47d7a94996bda0c52dcfd677
MediaPipe Team 2021-09-13 16:56:21 -07:00 (committed by chuoling)
parent 6abec128ed
commit 137e1cc763
31 changed files with 2051 additions and 53 deletions

Binary image file changed, not shown (size: 56 KiB → 77 KiB).

@@ -88,11 +88,11 @@ from [COCO topology](https://cocodataset.org/#keypoints-2020).

Method | Yoga <br/> [`mAP`] | Yoga <br/> [`PCK@0.2`] | Dance <br/> [`mAP`] | Dance <br/> [`PCK@0.2`] | HIIT <br/> [`mAP`] | HIIT <br/> [`PCK@0.2`]
----------------------------------------------------------------------------------------------------- | -----------------: | ---------------------: | ------------------: | ----------------------: | -----------------: | ---------------------:
-BlazePose.Heavy | 68.1 | **96.4** | 73.0 | **97.2** | 74.0 | **97.5**
+BlazePose GHUM Heavy | 68.1 | **96.4** | 73.0 | **97.2** | 74.0 | **97.5**
-BlazePose.Full | 62.6 | **95.5** | 67.4 | **96.3** | 68.0 | **95.7**
+BlazePose GHUM Full | 62.6 | **95.5** | 67.4 | **96.3** | 68.0 | **95.7**
-BlazePose.Lite | 45.0 | **90.2** | 53.6 | **92.5** | 53.8 | **93.5**
+BlazePose GHUM Lite | 45.0 | **90.2** | 53.6 | **92.5** | 53.8 | **93.5**
-[AlphaPose.ResNet50](https://github.com/MVIG-SJTU/AlphaPose) | 63.4 | **96.0** | 57.8 | **95.5** | 63.4 | **96.0**
+[AlphaPose ResNet50](https://github.com/MVIG-SJTU/AlphaPose) | 63.4 | **96.0** | 57.8 | **95.5** | 63.4 | **96.0**
-[Apple.Vision](https://developer.apple.com/documentation/vision/detecting_human_body_poses_in_images) | 32.8 | **82.7** | 36.4 | **91.4** | 44.5 | **88.6**
+[Apple Vision](https://developer.apple.com/documentation/vision/detecting_human_body_poses_in_images) | 32.8 | **82.7** | 36.4 | **91.4** | 44.5 | **88.6**

![pose_tracking_pck_chart.png](../images/mobile/pose_tracking_pck_chart.png) |
:--------------------------------------------------------------------------: |

@@ -102,10 +102,10 @@ We designed our models specifically for live perception use cases, so all of
them work in real-time on the majority of modern devices.

Method | Latency <br/> Pixel 3 [TFLite GPU](https://www.tensorflow.org/lite/performance/gpu_advanced) | Latency <br/> MacBook Pro (15-inch 2017)
-------------------- | -------------------------------------------------------------------------------------------: | ---------------------------------------:
-BlazePose.Heavy | 53 ms | 38 ms
+BlazePose GHUM Heavy | 53 ms | 38 ms
-BlazePose.Full | 25 ms | 27 ms
+BlazePose GHUM Full | 25 ms | 27 ms
-BlazePose.Lite | 20 ms | 25 ms
+BlazePose GHUM Lite | 20 ms | 25 ms

## Models

@@ -237,7 +237,7 @@ pixel respectively. Please refer to the platform-specific usage examples below
for usage details.

*Fig 6. Example of MediaPipe Pose segmentation mask.* |
:-----------------------------------------------------------: |
<video autoplay muted loop preload style="height: auto; width: 480px"><source src="../images/mobile/pose_segmentation.mp4" type="video/mp4"></video> |

### Python Solution API


@@ -217,6 +217,7 @@ absl::Status PacketThinnerCalculator::Open(CalculatorContext* cc) {
     header->format = video_header.format;
     header->width = video_header.width;
     header->height = video_header.height;
+    header->duration = video_header.duration;
     header->frame_rate = new_frame_rate;
     cc->Outputs().Index(0).SetHeader(Adopt(header.release()));
   } else {


@@ -356,6 +356,57 @@ cc_library(
     alwayslink = 1,
 )

+mediapipe_proto_library(
+    name = "landmarks_to_tensor_calculator_proto",
+    srcs = ["landmarks_to_tensor_calculator.proto"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//mediapipe/framework:calculator_options_proto",
+        "//mediapipe/framework:calculator_proto",
+    ],
+)
+
+cc_library(
+    name = "landmarks_to_tensor_calculator",
+    srcs = ["landmarks_to_tensor_calculator.cc"],
+    hdrs = ["landmarks_to_tensor_calculator.h"],
+    copts = select({
+        "//mediapipe:apple": [
+            "-x objective-c++",
+            "-fobjc-arc",  # enable reference-counting
+        ],
+        "//conditions:default": [],
+    }),
+    visibility = ["//visibility:public"],
+    deps = [
+        ":landmarks_to_tensor_calculator_cc_proto",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework/api2:node",
+        "//mediapipe/framework/formats:landmark_cc_proto",
+        "//mediapipe/framework/formats:tensor",
+        "//mediapipe/framework/port:ret_check",
+    ],
+    alwayslink = 1,
+)
+
+cc_test(
+    name = "landmarks_to_tensor_calculator_test",
+    srcs = ["landmarks_to_tensor_calculator_test.cc"],
+    deps = [
+        ":landmarks_to_tensor_calculator",
+        ":landmarks_to_tensor_calculator_cc_proto",
+        "//mediapipe/framework:calculator_cc_proto",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework:calculator_runner",
+        "//mediapipe/framework/formats:landmark_cc_proto",
+        "//mediapipe/framework/formats:tensor",
+        "//mediapipe/framework/port:gtest_main",
+        "//mediapipe/framework/port:parse_text_proto",
+        "@com_google_absl//absl/memory",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
 mediapipe_proto_library(
     name = "tensors_to_floats_calculator_proto",
     srcs = ["tensors_to_floats_calculator.proto"],


@@ -99,13 +99,11 @@ class InferenceCalculator : public NodeIntf {
       kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"};
   static constexpr SideInput<TfLiteModelPtr>::Optional kSideInModel{"MODEL"};
   static constexpr Output<std::vector<Tensor>> kOutTensors{"TENSORS"};
-  static constexpr SideInput<std::string>::Optional kNnApiDelegateCacheDir{
-      "NNAPI_CACHE_DIR"};
-  static constexpr SideInput<std::string>::Optional kNnApiDelegateModelToken{
-      "NNAPI_MODEL_TOKEN"};
+  static constexpr SideInput<
+      mediapipe::InferenceCalculatorOptions::Delegate>::Optional kDelegate{
+      "DELEGATE"};
   MEDIAPIPE_NODE_CONTRACT(kInTensors, kSideInCustomOpResolver, kSideInModel,
-                          kOutTensors, kNnApiDelegateCacheDir,
-                          kNnApiDelegateModelToken);
+                          kOutTensors, kDelegate);

  protected:
   using TfLiteDelegatePtr =
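The two NNAPI-specific string side packets are folded into a single `DELEGATE` side packet carrying a full `Delegate` proto. A hedged sketch of the client side (the side-packet name `delegate`, the cache dir, and the model token are placeholders; it assumes the node declares `input_side_packet: "DELEGATE:delegate"` in its graph config):

```cpp
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"

// Illustrative only: pass NNAPI compilation-cache settings through the new
// DELEGATE side packet (previously the NNAPI_CACHE_DIR / NNAPI_MODEL_TOKEN
// string side packets).
absl::Status StartWithNnApiDelegate(mediapipe::CalculatorGraph& graph) {
  mediapipe::InferenceCalculatorOptions::Delegate delegate;
  delegate.mutable_nnapi()->set_cache_dir("/data/local/tmp/nnapi_cache");
  delegate.mutable_nnapi()->set_model_token("my_model");
  return graph.StartRun(
      {{"delegate",
        mediapipe::MakePacket<mediapipe::InferenceCalculatorOptions::Delegate>(
            std::move(delegate))}});
}
```

The calculator merges this packet over the `delegate` field of its node options, so on CPU it can equally select TFLite, XNNPACK, or NNAPI, as the CPU implementation below shows.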


@@ -18,6 +18,9 @@ package mediapipe;

 import "mediapipe/framework/calculator.proto";

+option java_package = "com.google.mediapipe.calculator.proto";
+option java_outer_classname = "InferenceCalculatorProto";
+
 // Full Example:
 //
 // node {


@@ -50,11 +50,13 @@ int GetXnnpackDefaultNumThreads() {
 // Returns number of threads to configure XNNPACK delegate with.
 // Returns user provided value if specified. Otherwise, tries to choose optimal
 // number of threads depending on the device.
-int GetXnnpackNumThreads(const mediapipe::InferenceCalculatorOptions& opts) {
+int GetXnnpackNumThreads(
+    const bool opts_has_delegate,
+    const mediapipe::InferenceCalculatorOptions::Delegate& opts_delegate) {
   static constexpr int kDefaultNumThreads = -1;
-  if (opts.has_delegate() && opts.delegate().has_xnnpack() &&
-      opts.delegate().xnnpack().num_threads() != kDefaultNumThreads) {
-    return opts.delegate().xnnpack().num_threads();
+  if (opts_has_delegate && opts_delegate.has_xnnpack() &&
+      opts_delegate.xnnpack().num_threads() != kDefaultNumThreads) {
+    return opts_delegate.xnnpack().num_threads();
   }
   return GetXnnpackDefaultNumThreads();
 }

@@ -175,33 +177,40 @@ absl::Status InferenceCalculatorCpuImpl::LoadDelegateAndAllocateTensors(
 absl::Status InferenceCalculatorCpuImpl::LoadDelegate(CalculatorContext* cc) {
   const auto& calculator_opts =
       cc->Options<mediapipe::InferenceCalculatorOptions>();
-  if (calculator_opts.has_delegate() &&
-      calculator_opts.delegate().has_tflite()) {
+  auto opts_delegate = calculator_opts.delegate();
+  if (!kDelegate(cc).IsEmpty()) {
+    mediapipe::InferenceCalculatorOptions::Delegate input_side_packet_delegate =
+        kDelegate(cc).Get();
+    CHECK(input_side_packet_delegate.has_tflite() ||
+          input_side_packet_delegate.has_xnnpack() ||
+          input_side_packet_delegate.has_nnapi() ||
+          input_side_packet_delegate.delegate_case() ==
+              mediapipe::InferenceCalculatorOptions::Delegate::DELEGATE_NOT_SET)
+        << "inference_calculator_cpu only supports delegate input side packet "
+        << "for TFLite, XNNPack and Nnapi";
+    opts_delegate.MergeFrom(input_side_packet_delegate);
+  }
+  const bool opts_has_delegate =
+      calculator_opts.has_delegate() || !kDelegate(cc).IsEmpty();
+  if (opts_has_delegate && opts_delegate.has_tflite()) {
     // Default tflite inference requested - no need to modify graph.
     return absl::OkStatus();
   }

 #if defined(MEDIAPIPE_ANDROID)
-  const bool nnapi_requested = calculator_opts.has_delegate()
-                                   ? calculator_opts.delegate().has_nnapi()
-                                   : calculator_opts.use_nnapi();
+  const bool nnapi_requested = opts_has_delegate ? opts_delegate.has_nnapi()
+                                                 : calculator_opts.use_nnapi();
   if (nnapi_requested) {
     // Attempt to use NNAPI.
     // If not supported, the default CPU delegate will be created and used.
     interpreter_->SetAllowFp16PrecisionForFp32(1);
     tflite::StatefulNnApiDelegate::Options options;
-    const auto& nnapi = calculator_opts.delegate().nnapi();
+    const auto& nnapi = opts_delegate.nnapi();
     // Set up cache_dir and model_token for NNAPI compilation cache.
     options.cache_dir =
         nnapi.has_cache_dir() ? nnapi.cache_dir().c_str() : nullptr;
-    if (!kNnApiDelegateCacheDir(cc).IsEmpty()) {
-      options.cache_dir = kNnApiDelegateCacheDir(cc).Get().c_str();
-    }
     options.model_token =
         nnapi.has_model_token() ? nnapi.model_token().c_str() : nullptr;
-    if (!kNnApiDelegateModelToken(cc).IsEmpty()) {
-      options.model_token = kNnApiDelegateModelToken(cc).Get().c_str();
-    }
     delegate_ = TfLiteDelegatePtr(new tflite::StatefulNnApiDelegate(options),
                                   [](TfLiteDelegate*) {});
     RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),

@@ -213,13 +222,13 @@ absl::Status InferenceCalculatorCpuImpl::LoadDelegate(CalculatorContext* cc) {
 #if defined(__EMSCRIPTEN__)
   const bool use_xnnpack = true;
 #else
-  const bool use_xnnpack = calculator_opts.has_delegate() &&
-                           calculator_opts.delegate().has_xnnpack();
+  const bool use_xnnpack = opts_has_delegate && opts_delegate.has_xnnpack();
 #endif  // defined(__EMSCRIPTEN__)

   if (use_xnnpack) {
     TfLiteXNNPackDelegateOptions xnnpack_opts{};
-    xnnpack_opts.num_threads = GetXnnpackNumThreads(calculator_opts);
+    xnnpack_opts.num_threads =
+        GetXnnpackNumThreads(opts_has_delegate, opts_delegate);
     delegate_ = TfLiteDelegatePtr(TfLiteXNNPackDelegateCreate(&xnnpack_opts),
                                   &TfLiteXNNPackDelegateDelete);
     RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),


@@ -95,19 +95,30 @@ absl::Status InferenceCalculatorGlImpl::UpdateContract(CalculatorContract* cc) {
 absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) {
   const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>();
-  use_advanced_gpu_api_ = options.has_delegate() &&
-                          options.delegate().has_gpu() &&
-                          options.delegate().gpu().use_advanced_gpu_api();
-  allow_precision_loss_ = options.delegate().gpu().allow_precision_loss();
-  tflite_gpu_runner_api_ = options.delegate().gpu().api();
-  tflite_gpu_runner_usage_ = options.delegate().gpu().usage();
-  use_kernel_caching_ = use_advanced_gpu_api_ &&
-                        options.delegate().gpu().has_cached_kernel_path();
+  mediapipe::InferenceCalculatorOptions::Delegate delegate = options.delegate();
+  if (!kDelegate(cc).IsEmpty()) {
+    mediapipe::InferenceCalculatorOptions::Delegate input_side_packet_delegate =
+        kDelegate(cc).Get();
+    CHECK(input_side_packet_delegate.has_gpu() ||
+          input_side_packet_delegate.delegate_case() ==
+              mediapipe::InferenceCalculatorOptions::Delegate::DELEGATE_NOT_SET)
+        << "inference_calculator_gl only supports delegate input side packet "
+        << "for Gpu";
+    delegate.MergeFrom(input_side_packet_delegate);
+  }
+  const bool has_delegate = options.has_delegate() || !kDelegate(cc).IsEmpty();
+  use_advanced_gpu_api_ = has_delegate && delegate.has_gpu() &&
+                          delegate.gpu().use_advanced_gpu_api();
+  allow_precision_loss_ = delegate.gpu().allow_precision_loss();
+  tflite_gpu_runner_api_ = delegate.gpu().api();
+  tflite_gpu_runner_usage_ = delegate.gpu().usage();
+  use_kernel_caching_ =
+      use_advanced_gpu_api_ && delegate.gpu().has_cached_kernel_path();
   use_gpu_delegate_ = !use_advanced_gpu_api_;

   if (use_kernel_caching_) {
 #ifdef MEDIAPIPE_ANDROID
-    cached_kernel_filename_ = options.delegate().gpu().cached_kernel_path() +
+    cached_kernel_filename_ = delegate.gpu().cached_kernel_path() +
                               mediapipe::File::Basename(options.model_path()) +
                               ".ker";
 #endif  // MEDIAPIPE_ANDROID
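The GL implementation accepts the same `DELEGATE` side packet but, per the `CHECK` above, only honors its `gpu` field. A minimal sketch, with placeholder packet name and cache path, of enabling the advanced GPU API at start time:

```cpp
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"

// Illustrative only: request TFLite's advanced GPU API; the kernel-cache
// path is a placeholder and is only consulted on Android (see above).
absl::Status StartWithGpuDelegate(mediapipe::CalculatorGraph& graph) {
  mediapipe::InferenceCalculatorOptions::Delegate delegate;
  delegate.mutable_gpu()->set_use_advanced_gpu_api(true);
  delegate.mutable_gpu()->set_cached_kernel_path("/data/local/tmp/");
  return graph.StartRun(
      {{"delegate",
        mediapipe::MakePacket<mediapipe::InferenceCalculatorOptions::Delegate>(
            std::move(delegate))}});
}
```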


@@ -0,0 +1,101 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/tensor/landmarks_to_tensor_calculator.h"
#include <memory>
#include "mediapipe/calculators/tensor/landmarks_to_tensor_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
namespace mediapipe {
namespace api2 {
namespace {
float GetAttribute(
const Landmark& landmark,
const LandmarksToTensorCalculatorOptions::Attribute& attribute) {
switch (attribute) {
case LandmarksToTensorCalculatorOptions::X:
return landmark.x();
case LandmarksToTensorCalculatorOptions::Y:
return landmark.y();
case LandmarksToTensorCalculatorOptions::Z:
return landmark.z();
case LandmarksToTensorCalculatorOptions::VISIBILITY:
return landmark.visibility();
case LandmarksToTensorCalculatorOptions::PRESENCE:
return landmark.presence();
}
}
} // namespace
class LandmarksToTensorCalculatorImpl
: public NodeImpl<LandmarksToTensorCalculator> {
public:
absl::Status Open(CalculatorContext* cc) override {
options_ = cc->Options<LandmarksToTensorCalculatorOptions>();
RET_CHECK(options_.attributes_size() > 0)
<< "At least one attribute must be specified";
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) override {
if (kInLandmarkList(cc).IsEmpty()) {
return absl::OkStatus();
}
// Get input landmarks.
const auto& in_landmarks = *kInLandmarkList(cc);
// Determine tensor shape.
const int n_landmarks = in_landmarks.landmark_size();
const int n_attributes = options_.attributes_size();
auto tensor_shape = options_.flatten()
? Tensor::Shape{1, n_landmarks * n_attributes}
: Tensor::Shape{1, n_landmarks, n_attributes};
// Create empty tensor.
Tensor tensor(Tensor::ElementType::kFloat32, tensor_shape);
auto* buffer = tensor.GetCpuWriteView().buffer<float>();
// Fill tensor with landmark attributes.
for (int i = 0; i < n_landmarks; ++i) {
for (int j = 0; j < n_attributes; ++j) {
buffer[i * n_attributes + j] =
GetAttribute(in_landmarks.landmark(i), options_.attributes(j));
}
}
// Return vector with a single tensor.
auto result = std::vector<Tensor>();
result.push_back(std::move(tensor));
kOutTensors(cc).Send(std::move(result));
return absl::OkStatus();
}
private:
LandmarksToTensorCalculatorOptions options_;
};
MEDIAPIPE_NODE_IMPLEMENTATION(LandmarksToTensorCalculatorImpl);
} // namespace api2
} // namespace mediapipe


@@ -0,0 +1,61 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_LANDMARKS_TO_TENSOR_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_LANDMARKS_TO_TENSOR_CALCULATOR_H_
#include <memory>
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/tensor.h"
namespace mediapipe {
namespace api2 {
// A calculator for converting landmarks into a Tensor.
//
// Input:
// LANDMARKS - LandmarkList
// Landmarks to be converted into a Tensor.
//
// Output:
// TENSORS - std::vector<Tensor>
// Vector containing a single Tensor populated with landmark values.
//
// Example:
// node {
// calculator: "LandmarksToTensorCalculator"
// input_stream: "LANDMARKS:landmarks"
// output_stream: "TENSORS:tensors"
// options: {
// [mediapipe.LandmarksToTensorCalculatorOptions.ext] {
// attributes: [X, Y, Z, VISIBILITY, PRESENCE]
// # flatten: true
// }
// }
// }
class LandmarksToTensorCalculator : public NodeIntf {
public:
static constexpr Input<LandmarkList>::Optional kInLandmarkList{"LANDMARKS"};
static constexpr Output<std::vector<Tensor>> kOutTensors{"TENSORS"};
MEDIAPIPE_NODE_INTERFACE(LandmarksToTensorCalculator, kInLandmarkList,
kOutTensors);
};
} // namespace api2
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_LANDMARKS_TO_TENSOR_CALCULATOR_H_


@@ -0,0 +1,44 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// The option proto for the LandmarksToTensorCalculator.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message LandmarksToTensorCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional LandmarksToTensorCalculatorOptions ext = 394810235;
}
enum Attribute {
X = 0;
Y = 1;
Z = 2;
VISIBILITY = 3;
PRESENCE = 4;
}
// Subset and order of attributes as they should appear in the output Tensor.
// Should contain at least one attribute.
repeated Attribute attributes = 1;
// Collapses all landmark attributes into a one dimensional tensor (i.e.
// switches from (n_landmarks, n_attributes) to (n_landmarks * n_attributes)
// representation).
optional bool flatten = 2 [default = false];
}
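To make the `flatten` option concrete, here is a small illustrative sketch (not part of the commit; 33 landmarks matches a typical full-body pose model, and the attribute list is hypothetical) of the shapes the calculator produces:

```cpp
// Illustrative only: mirrors the shape logic in LandmarksToTensorCalculator.
#include <cstdio>

int main() {
  const int n_landmarks = 33;  // hypothetical input size
  const int n_attributes = 5;  // e.g. attributes: [X, Y, Z, VISIBILITY, PRESENCE]
  // flatten = false -> shape {1, n_landmarks, n_attributes} = {1, 33, 5};
  // element (i, j) is attribute j of landmark i, stored at
  // buffer[i * n_attributes + j].
  std::printf("unflattened: {1, %d, %d}\n", n_landmarks, n_attributes);
  // flatten = true -> shape {1, n_landmarks * n_attributes} = {1, 165}.
  std::printf("flattened: {1, %d}\n", n_landmarks * n_attributes);
  return 0;
}
```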


@@ -0,0 +1,155 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
#include "mediapipe/calculators/tensor/landmarks_to_tensor_calculator.pb.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
using ::mediapipe::ParseTextProtoOrDie;
using Node = ::mediapipe::CalculatorGraphConfig::Node;
void RunLandmarks(mediapipe::CalculatorRunner* runner,
const LandmarkList& landmarks) {
runner->MutableInputs()
->Tag("LANDMARKS")
.packets.push_back(MakePacket<LandmarkList>(landmarks).At(Timestamp(0)));
MP_ASSERT_OK(runner->Run());
}
const Tensor& GetOutputTensor(mediapipe::CalculatorRunner* runner) {
const auto& output_packets = runner->Outputs().Tag("TENSORS").packets;
EXPECT_EQ(output_packets.size(), 1);
const auto& tensors = output_packets[0].Get<std::vector<Tensor>>();
EXPECT_EQ(tensors.size(), 1);
return tensors[0];
}
void ValidateTensor(const Tensor& tensor,
const std::vector<int>& expected_shape,
const std::vector<float>& expected_values) {
EXPECT_EQ(tensor.shape().dims, expected_shape);
EXPECT_EQ(tensor.shape().num_elements(), expected_values.size());
auto* tensor_buffer = tensor.GetCpuReadView().buffer<float>();
const std::vector<float> tensor_values(
tensor_buffer, tensor_buffer + tensor.shape().num_elements());
EXPECT_THAT(tensor_values, testing::ElementsAreArray(expected_values));
}
TEST(LandmarksToTensorCalculatorTest, AllAttributes) {
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "LandmarksToTensorCalculator"
input_stream: "LANDMARKS:landmarks"
output_stream: "TENSORS:tensors"
options: {
[mediapipe.LandmarksToTensorCalculatorOptions.ext] {
attributes: [ X, Y, Z, VISIBILITY, PRESENCE ]
}
}
)pb"));
LandmarkList landmarks;
auto* landmark1 = landmarks.add_landmark();
landmark1->set_x(1.0f);
landmark1->set_y(2.0f);
landmark1->set_z(3.0f);
landmark1->set_visibility(4.0f);
landmark1->set_presence(5.0f);
auto* landmark2 = landmarks.add_landmark();
landmark2->set_x(6.0f);
landmark2->set_y(7.0f);
landmark2->set_z(8.0f);
landmark2->set_visibility(9.0f);
landmark2->set_presence(10.0f);
RunLandmarks(&runner, landmarks);
const auto& tensor = GetOutputTensor(&runner);
ValidateTensor(tensor, /*expected_shape=*/{1, 2, 5}, /*expected_values=*/
{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f});
}
TEST(LandmarksToTensorCalculatorTest, XYZAttributes) {
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "LandmarksToTensorCalculator"
input_stream: "LANDMARKS:landmarks"
output_stream: "TENSORS:tensors"
options: {
[mediapipe.LandmarksToTensorCalculatorOptions.ext] {
attributes: [ X, Y, Z ]
}
}
)pb"));
LandmarkList landmarks;
auto* landmark1 = landmarks.add_landmark();
landmark1->set_x(1.0f);
landmark1->set_y(2.0f);
landmark1->set_z(3.0f);
auto* landmark2 = landmarks.add_landmark();
landmark2->set_x(6.0f);
landmark2->set_y(7.0f);
landmark2->set_z(8.0f);
RunLandmarks(&runner, landmarks);
const auto& tensor = GetOutputTensor(&runner);
ValidateTensor(tensor, /*expected_shape=*/{1, 2, 3}, /*expected_values=*/
{1.0f, 2.0f, 3.0f, 6.0f, 7.0f, 8.0f});
}
TEST(LandmarksToTensorCalculatorTest, XYZAttributes_Flatten) {
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "LandmarksToTensorCalculator"
input_stream: "LANDMARKS:landmarks"
output_stream: "TENSORS:tensors"
options: {
[mediapipe.LandmarksToTensorCalculatorOptions.ext] {
attributes: [ X, Y, Z ]
flatten: true
}
}
)pb"));
LandmarkList landmarks;
auto* landmark1 = landmarks.add_landmark();
landmark1->set_x(1.0f);
landmark1->set_y(2.0f);
landmark1->set_z(3.0f);
auto* landmark2 = landmarks.add_landmark();
landmark2->set_x(6.0f);
landmark2->set_y(7.0f);
landmark2->set_z(8.0f);
RunLandmarks(&runner, landmarks);
const auto& tensor = GetOutputTensor(&runner);
ValidateTensor(tensor, /*expected_shape=*/{1, 6}, /*expected_values=*/
{1.0f, 2.0f, 3.0f, 6.0f, 7.0f, 8.0f});
}
} // namespace
} // namespace mediapipe


@@ -57,6 +57,16 @@ mediapipe_proto_library(
     ],
 )

+mediapipe_proto_library(
+    name = "filter_detections_calculator_proto",
+    srcs = ["filter_detections_calculator.proto"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//mediapipe/framework:calculator_options_proto",
+        "//mediapipe/framework:calculator_proto",
+    ],
+)
+
 mediapipe_proto_library(
     name = "timed_box_list_id_to_label_calculator_proto",
     srcs = ["timed_box_list_id_to_label_calculator.proto"],

@@ -158,6 +168,21 @@ cc_test(
     ],
 )

+cc_test(
+    name = "filter_detections_calculator_test",
+    size = "small",
+    srcs = ["filter_detections_calculator_test.cc"],
+    deps = [
+        ":filter_detections_calculator",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework:calculator_runner",
+        "//mediapipe/framework/deps:message_matchers",
+        "//mediapipe/framework/formats:detection_cc_proto",
+        "//mediapipe/framework/port:gtest_main",
+        "//mediapipe/framework/port:parse_text_proto",
+    ],
+)
+
 cc_library(
     name = "packet_latency_calculator",
     srcs = ["packet_latency_calculator.cc"],

@@ -372,6 +397,20 @@ cc_library(
     alwayslink = 1,
 )

+cc_library(
+    name = "filter_detections_calculator",
+    srcs = ["filter_detections_calculator.cc"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":filter_detections_calculator_cc_proto",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework/formats:detection_cc_proto",
+        "//mediapipe/framework/port:status",
+        "@com_google_absl//absl/memory",
+    ],
+    alwayslink = 1,
+)
+
 cc_library(
     name = "landmarks_to_detection_calculator",
     srcs = ["landmarks_to_detection_calculator.cc"],


@@ -0,0 +1,81 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iterator>
#include <memory>
#include <string>
#include <vector>
#include "absl/memory/memory.h"
#include "mediapipe/calculators/util/filter_detections_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
const char kInputDetectionsTag[] = "INPUT_DETECTIONS";
const char kOutputDetectionsTag[] = "OUTPUT_DETECTIONS";
//
// Calculator to filter out detections that do not meet the criteria specified
// in options.
//
class FilterDetectionsCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag(kInputDetectionsTag));
RET_CHECK(cc->Outputs().HasTag(kOutputDetectionsTag));
cc->Inputs().Tag(kInputDetectionsTag).Set<std::vector<Detection>>();
cc->Outputs().Tag(kOutputDetectionsTag).Set<std::vector<Detection>>();
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) override {
cc->SetOffset(TimestampDiff(0));
options_ = cc->Options<mediapipe::FilterDetectionsCalculatorOptions>();
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) final {
const auto& input_detections =
cc->Inputs().Tag(kInputDetectionsTag).Get<std::vector<Detection>>();
auto output_detections = absl::make_unique<std::vector<Detection>>();
for (const Detection& detection : input_detections) {
RET_CHECK_GT(detection.score_size(), 0);
// Note: only score at index 0 supported.
if (detection.score(0) >= options_.min_score()) {
output_detections->push_back(detection);
}
}
cc->Outputs()
.Tag(kOutputDetectionsTag)
.Add(output_detections.release(), cc->InputTimestamp());
return absl::OkStatus();
}
private:
mediapipe::FilterDetectionsCalculatorOptions options_;
};
REGISTER_CALCULATOR(FilterDetectionsCalculator);
} // namespace mediapipe


@@ -0,0 +1,28 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message FilterDetectionsCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional FilterDetectionsCalculatorOptions ext = 395478132;
}
// Detections lower than this score get filtered out.
optional float min_score = 1;
}
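A minimal usage sketch (the stream names and the 0.6 threshold are placeholders, not from this commit); note from the implementation above that only `score(0)` of each detection is consulted:

```cpp
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

// Illustrative node config: keep only detections whose first score >= 0.6.
mediapipe::CalculatorGraphConfig::Node MakeFilterDetectionsNode() {
  return mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(
      R"pb(
        calculator: "FilterDetectionsCalculator"
        input_stream: "INPUT_DETECTIONS:detections"
        output_stream: "OUTPUT_DETECTIONS:filtered_detections"
        options {
          [mediapipe.FilterDetectionsCalculatorOptions.ext] { min_score: 0.6 }
        }
      )pb");
}
```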


@@ -0,0 +1,100 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
using ::testing::ElementsAre;
absl::Status RunGraph(std::vector<Detection>& input_detections,
std::vector<Detection>* output_detections) {
CalculatorRunner runner(R"pb(
calculator: "FilterDetectionsCalculator"
input_stream: "INPUT_DETECTIONS:input_detections"
output_stream: "OUTPUT_DETECTIONS:output_detections"
options {
[mediapipe.FilterDetectionsCalculatorOptions.ext] { min_score: 0.5 }
}
)pb");
const Timestamp input_timestamp = Timestamp(0);
runner.MutableInputs()
->Tag("INPUT_DETECTIONS")
.packets.push_back(MakePacket<std::vector<Detection>>(input_detections)
.At(input_timestamp));
MP_RETURN_IF_ERROR(runner.Run()) << "Calculator run failed.";
const std::vector<Packet>& output_packets =
runner.Outputs().Tag("OUTPUT_DETECTIONS").packets;
RET_CHECK_EQ(output_packets.size(), 1);
*output_detections = output_packets[0].Get<std::vector<Detection>>();
return absl::OkStatus();
}
TEST(FilterDetectionsCalculatorTest, TestFilterDetections) {
std::vector<Detection> input_detections;
Detection d1, d2;
d1.add_score(0.2);
d2.add_score(0.8);
input_detections.push_back(d1);
input_detections.push_back(d2);
std::vector<Detection> output_detections;
MP_EXPECT_OK(RunGraph(input_detections, &output_detections));
EXPECT_THAT(output_detections, ElementsAre(mediapipe::EqualsProto(d2)));
}
TEST(FilterDetectionsCalculatorTest, TestFilterDetectionsMultiple) {
std::vector<Detection> input_detections;
Detection d1, d2, d3, d4;
d1.add_score(0.3);
d2.add_score(0.4);
d3.add_score(0.5);
d4.add_score(0.6);
input_detections.push_back(d1);
input_detections.push_back(d2);
input_detections.push_back(d3);
input_detections.push_back(d4);
std::vector<Detection> output_detections;
MP_EXPECT_OK(RunGraph(input_detections, &output_detections));
EXPECT_THAT(output_detections, ElementsAre(mediapipe::EqualsProto(d3),
mediapipe::EqualsProto(d4)));
}
TEST(FilterDetectionsCalculatorTest, TestFilterDetectionsEmpty) {
std::vector<Detection> input_detections;
std::vector<Detection> output_detections;
MP_EXPECT_OK(RunGraph(input_detections, &output_detections));
EXPECT_EQ(output_detections.size(), 0);
}
} // namespace
} // namespace mediapipe


@@ -25,7 +25,7 @@ node: {
   output_stream: "input_video_cpu"
 }

-# Transforms the input image on CPU to a 480x640 image.
+# Scale the image's longer side to 640, keeping aspect ratio.
 node: {
   calculator: "ImageTransformationCalculator"
   input_stream: "IMAGE:input_video_cpu"


@@ -48,5 +48,14 @@ std::string ClassRegistry::GetMethodName(std::string cls, std::string method) {
   return method;
 }

+std::string ClassRegistry::GetFieldName(std::string cls, std::string field) {
+  std::string key = absl::StrFormat("%s##%s", cls, field);
+  auto match = renaming_map_.find(key);
+  if (match != renaming_map_.end()) {
+    return match->second;
+  }
+  return field;
+}
+
 }  // namespace android
 }  // namespace mediapipe


@@ -33,6 +33,7 @@ class ClassRegistry {
       absl::node_hash_map<std::string, std::string> renaming_map);
   std::string GetClassName(std::string cls);
   std::string GetMethodName(std::string cls, std::string method);
+  std::string GetFieldName(std::string cls, std::string field);

   // TODO: Just have the prefix instead of all these constants.
   static constexpr char const* kAndroidAssetUtilClassName =

@@ -59,6 +60,8 @@ class ClassRegistry {
       "com/google/mediapipe/framework/PacketGetter";
   static constexpr char const* kPacketWithHeaderCallbackClassName =
       "com/google/mediapipe/framework/PacketWithHeaderCallback";
+  static constexpr char const* kProtoUtilSerializedMessageClassName =
+      "com/google/mediapipe/framework/ProtoUtil$SerializedMessage";

  private:
   ClassRegistry();


@@ -156,10 +156,20 @@ bool ThrowIfError(JNIEnv* env, absl::Status status) {
 }

 SerializedMessageIds::SerializedMessageIds(JNIEnv* env, jobject data) {
-  jclass j_class = reinterpret_cast<jclass>(env->NewGlobalRef(env->FindClass(
-      "com/google/mediapipe/framework/ProtoUtil$SerializedMessage")));
-  type_name_id = env->GetFieldID(j_class, "typeName", "Ljava/lang/String;");
-  value_id = env->GetFieldID(j_class, "value", "[B");
+  auto& class_registry = mediapipe::android::ClassRegistry::GetInstance();
+  std::string serialized_message(
+      mediapipe::android::ClassRegistry::kProtoUtilSerializedMessageClassName);
+  std::string serialized_message_obfuscated =
+      class_registry.GetClassName(serialized_message);
+  std::string type_name_obfuscated =
+      class_registry.GetFieldName(serialized_message, "typeName");
+  std::string value_obfuscated =
+      class_registry.GetFieldName(serialized_message, "value");
+  jclass j_class = reinterpret_cast<jclass>(
+      env->NewGlobalRef(env->FindClass(serialized_message_obfuscated.c_str())));
+  type_name_id = env->GetFieldID(j_class, type_name_obfuscated.c_str(),
+                                 "Ljava/lang/String;");
+  value_id = env->GetFieldID(j_class, value_obfuscated.c_str(), "[B");
 }

 }  // namespace android


@@ -225,6 +225,12 @@ void RegisterPacketCreatorNatives(JNIEnv *env) {
   AddJNINativeMethod(&packet_creator_methods, packet_creator,
                      "nativeCreateString", "(JLjava/lang/String;)J",
                      (void *)&PACKET_CREATOR_METHOD(nativeCreateString));
+  std::string serialized_message_name = class_registry.GetClassName(
+      mediapipe::android::ClassRegistry::kProtoUtilSerializedMessageClassName);
+  AddJNINativeMethod(&packet_creator_methods, packet_creator,
+                     "nativeCreateProto",
+                     "(JL" + serialized_message_name + ";)J",
+                     (void *)&PACKET_CREATOR_METHOD(nativeCreateProto));
   RegisterNativesVector(env, packet_creator_class, packet_creator_methods);
 }


@@ -119,6 +119,18 @@ public class ImageSolutionBase extends SolutionBase {
               "Receiving a frame with invalid timestamp."));
       return;
     }
+    if (!solutionGraphStarted.get()) {
+      if (imageObj instanceof TextureFrame) {
+        ((TextureFrame) imageObj).release();
+      }
+      throwException(
+          "The solution graph hasn't been successfully started or an error occurred during graph"
+              + " initialization.",
+          new MediaPipeException(
+              MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
+              "Graph is not started."));
+      return;
+    }
     lastTimestamp = timestamp;
     Packet imagePacket = null;
     try {


@@ -87,6 +87,7 @@ public class SolutionBase {
     } else {
       Log.e(TAG, message, e);
     }
+    throw e;
   }

   /**

BIN mediapipe/modules/hand_landmark/hand_landmark.tflite: Normal file → Executable file (binary file not shown).


@@ -92,11 +92,11 @@ class SolveEpnpTest : public Test {
     const float scale = output_3d_points[0].z() / expected_3d_points_[0].z();

     for (int i = 0; i < kNumKeypoints; ++i) {
       EXPECT_NEAR(output_3d_points[i].x(), expected_3d_points_[i].x() * scale,
-                  1.e-6f);
+                  2.e-6f);
       EXPECT_NEAR(output_3d_points[i].y(), expected_3d_points_[i].y() * scale,
-                  1.e-6f);
+                  2.e-6f);
       EXPECT_NEAR(output_3d_points[i].z(), expected_3d_points_[i].z() * scale,
-                  1.e-6f);
+                  2.e-6f);
     }
   }


@@ -0,0 +1,555 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h"
#include <vector>
#include "tensorflow/lite/delegates/gpu/common/mediapipe/landmarks_to_transform_matrix.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/padding.h"
#include "tensorflow/lite/schema/schema_generated.h"
using ::tflite::gpu::BHWC;
using ::tflite::gpu::float2;
using ::tflite::gpu::float3;
using ::tflite::gpu::int2;
using ::tflite::gpu::int3;
using ::tflite::gpu::LandmarksToTransformMatrixV1Attributes;
using ::tflite::gpu::LandmarksToTransformMatrixV2Attributes;
using ::tflite::GetInput;
using ::tflite::GetOutput;
using ::tflite::GetTensorData;
using ::tflite::GetTensorShape;
using ::tflite::NumDimensions;
using ::tflite::NumInputs;
using ::tflite::NumOutputs;
using ::tflite::RuntimeShape;
namespace mediapipe {
namespace tflite_operations {
namespace {
constexpr int kDataInputTensor = 0;
constexpr int kOutputTensor = 0;
constexpr int3 kTransformMatrixShape(1, 4, 4);
float2 Read3DLandmarkXY(const float* data, int idx) {
float2 result;
result.x = data[idx * 3];
result.y = data[idx * 3 + 1];
return result;
}
float3 Read3DLandmarkXYZ(const float* data, int idx) {
float3 result;
result.x = data[idx * 3];
result.y = data[idx * 3 + 1];
result.z = data[idx * 3 + 2];
return result;
}
struct Mat3 {
Mat3() { data.resize(9); }
Mat3(float x00, float x01, float x02, float x10, float x11, float x12,
float x20, float x21, float x22)
: data{x00, x01, x02, x10, x11, x12, x20, x21, x22} {}
Mat3 operator*(const Mat3& other) {
Mat3 result;
for (int r = 0; r < 3; r++) {
for (int c = 0; c < 3; c++) {
float sum = 0;
for (int k = 0; k < 3; k++) {
sum += this->Get(r, k) * other.Get(k, c);
}
result.Set(r, c, sum);
}
}
return result;
}
float3 operator*(const float3& vec) const {
float3 result;
for (int r = 0; r < 3; r++) {
float sum = 0;
for (int k = 0; k < 3; k++) {
sum += this->Get(r, k) * vec[k];
}
result[r] = sum;
}
return result;
}
float Get(int x, int y) const { return data[x * 3 + y]; }
void Set(int x, int y, float val) { data[x * 3 + y] = val; }
std::vector<float> data;
};
struct Mat4 {
Mat4() { data.resize(16); }
Mat4(float x00, float x01, float x02, float x03, float x10, float x11,
float x12, float x13, float x20, float x21, float x22, float x23,
float x30, float x31, float x32, float x33)
: data{x00, x01, x02, x03, x10, x11, x12, x13,
x20, x21, x22, x23, x30, x31, x32, x33} {}
void operator*=(const Mat4& other) {
Mat4 result;
for (int r = 0; r < 4; r++) {
for (int c = 0; c < 4; c++) {
float sum = 0;
for (int k = 0; k < 4; k++) {
sum += this->Get(r, k) * other.Get(k, c);
}
result.Set(r, c, sum);
}
}
std::memcpy(this->data.data(), result.data.data(),
result.data.size() * sizeof(float));
}
float Get(int x, int y) const { return data[x * 4 + y]; }
void Set(int x, int y, float val) { data[x * 4 + y] = val; }
std::vector<float> data;
};
namespace v1 {
inline void LandmarksToTransformMatrixV1(
const LandmarksToTransformMatrixV1Attributes& params,
const RuntimeShape& input0_shape, const float* landmarks,
const RuntimeShape& output_shape, float* output_data) {
TFLITE_CHECK_EQ(input0_shape.DimensionsCount(), 4);
TFLITE_CHECK_EQ(output_shape.DimensionsCount(), 3);
TFLITE_CHECK_EQ(input0_shape.Dims(0), 1);
TFLITE_CHECK_EQ(input0_shape.Dims(1), 1);
TFLITE_CHECK_EQ(input0_shape.Dims(2), 1);
float2 left_landmark = Read3DLandmarkXY(landmarks, params.left_rotation_idx);
float2 right_landmark =
Read3DLandmarkXY(landmarks, params.right_rotation_idx);
float alpha = -std::atan((right_landmark.y - left_landmark.y) /
(right_landmark.x - left_landmark.x));
float2 max_value(-100000, -100000);
float2 min_value(100000, 100000);
for (int i = 0; i < params.subset.size(); i++) {
for (int j = 0; j < 2; j++) {
float2 landmark_current =
Read3DLandmarkXY(landmarks, params.subset[i][j]);
float2 rotated(
landmark_current.x * cos(alpha) - landmark_current.y * sin(alpha),
landmark_current.x * sin(alpha) + landmark_current.y * cos(alpha));
max_value = float2(std::max(max_value.x, rotated.x),
std::max(max_value.y, rotated.y));
min_value = float2(std::min(min_value.x, rotated.x),
std::min(min_value.y, rotated.y));
}
}
float2 bbox_size((max_value.x - min_value.x) * params.bbox_size_multiplier,
(max_value.y - min_value.y) * params.bbox_size_multiplier);
Mat3 scale_matrix(
bbox_size.x / params.landmarks_range, 0.0, 0.0, // first row
0.0, bbox_size.y / params.landmarks_range, 0.0, // second row
0.0, 0.0, 1.0); // third row
float2 middle((max_value.x + min_value.x) / 2.0,
(max_value.y + min_value.y) / 2.0);
float2 rotated_middle(middle.x * cos(-alpha) - middle.y * sin(-alpha),
middle.x * sin(-alpha) + middle.y * cos(-alpha));
Mat3 rotation_matrix(
cos(-alpha), -sin(-alpha),
(rotated_middle.x / params.landmarks_range) * 2.0 - 1.0, // first row
sin(-alpha), cos(-alpha),
(rotated_middle.y / params.landmarks_range) * 2.0 - 1.0, // second row
0, 0, 1); // third row
Mat3 to_relative(2.0 / (params.output_hw.w - 1.0), 0.0, -1.0, // first row
0.0, 2.0 / (params.output_hw.h - 1.0), -1.0, // second row
0.0, 0.0, 1.0); // third row
Mat3 to_absolute((params.input_hw.w - 1.0) / 2.0, 0.0,
(params.input_hw.w - 1.0) / 2.0, // first row
0.0, (params.input_hw.h - 1.0) / 2.0,
(params.input_hw.h - 1.0) / 2.0, // second row
0.0, 0.0, 1.0); // third row
// Inverse transformation matrix.
Mat3 itm = to_absolute * rotation_matrix * scale_matrix * to_relative;
output_data[0] = itm.Get(0, 0);
output_data[1] = itm.Get(0, 1);
output_data[2] = 0.0;
output_data[3] = itm.Get(0, 2);
output_data[4] = itm.Get(1, 0);
output_data[5] = itm.Get(1, 1);
output_data[6] = 0.0;
output_data[7] = itm.Get(1, 2);
output_data[8] = itm.Get(2, 0);
output_data[9] = itm.Get(2, 1);
output_data[10] = itm.Get(2, 2);
output_data[11] = 0.0;
output_data[12] = 0.0;
output_data[13] = 0.0;
output_data[14] = 0.0;
output_data[15] = 1.0;
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kDataInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
output_size->data[0] = kTransformMatrixShape.x;
output_size->data[1] = kTransformMatrixShape.y;
output_size->data[2] = kTransformMatrixShape.z;
return context->ResizeTensor(context, output, output_size);
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
LandmarksToTransformMatrixV1Attributes op_params;
BHWC output_shape;
auto status = tflite::gpu::ParseLandmarksToTransformMatrixV1Attributes(
node->custom_initial_data, node->custom_initial_data_size, &op_params,
&output_shape);
if (!status.ok()) {
context->ReportError(context, status.message().data());
return kTfLiteError;
}
if (op_params.bbox_size_multiplier == 0) {
context->ReportError(context, "Incorrect bbox_size_multiplier: %f",
op_params.bbox_size_multiplier);
return kTfLiteError;
}
if (op_params.dimensions != 3) {
context->ReportError(context, "Incorrect dimensions: %d",
op_params.dimensions);
return kTfLiteError;
}
if (op_params.input_hw.h <= 0 || op_params.input_hw.w <= 0) {
context->ReportError(context, "Incorrect input_hw: h = %d w = %d",
op_params.input_hw.h, op_params.input_hw.w);
return kTfLiteError;
}
if (op_params.output_hw.h <= 0 || op_params.output_hw.w <= 0) {
context->ReportError(context, "Incorrect output_hw: h = %d w = %d",
op_params.output_hw.h, op_params.output_hw.w);
return kTfLiteError;
}
if (op_params.landmarks_range <= 0) {
context->ReportError(context, "Incorrect landmarks_range: %d",
op_params.landmarks_range);
return kTfLiteError;
}
if (op_params.left_rotation_idx < 0) {
context->ReportError(context, "Incorrect left_rotation_idx: %d",
op_params.left_rotation_idx);
return kTfLiteError;
}
if (op_params.right_rotation_idx < 0) {
context->ReportError(context, "Incorrect right_rotation_idx: %d",
op_params.right_rotation_idx);
return kTfLiteError;
}
if (op_params.subset.empty()) {
context->ReportError(context, "Subset parameter is empty");
return kTfLiteError;
}
int counter = 0;
for (auto& val : op_params.subset) {
for (int i = 0; i < 2; i++) {
if (val[i] < 0) {
context->ReportError(context,
"Incorrect subset value: index = %d, value = %d",
counter, val[i]);
return kTfLiteError;
}
counter++;
}
}
const TfLiteTensor* input0 = GetInput(context, node, kDataInputTensor);
TF_LITE_ENSURE(context, input0 != nullptr);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
LandmarksToTransformMatrixV1(
op_params, GetTensorShape(input0), GetTensorData<float>(input0),
GetTensorShape(output), GetTensorData<float>(output));
return kTfLiteOk;
}
} // namespace v1
namespace v2 {
void EstimateRotationRadians(const float* input_data_0, int left_rotation_idx,
int right_rotation_idx,
float target_rotation_radians,
float* rotation_radians) {
const float3 left_landmark =
Read3DLandmarkXYZ(input_data_0, left_rotation_idx);
const float3 right_landmark =
Read3DLandmarkXYZ(input_data_0, right_rotation_idx);
const float left_x = left_landmark[0];
const float left_y = left_landmark[1];
const float right_x = right_landmark[0];
const float right_y = right_landmark[1];
float rotation = std::atan2(right_y - left_y, right_x - left_x);
rotation = target_rotation_radians - rotation;
*rotation_radians = rotation;
}
void EstimateCenterAndSize(const float* input_data_0,
std::vector<tflite::gpu::int2> subset_idxs,
float rotation_radians, float* crop_x, float* crop_y,
float* crop_width, float* crop_height) {
std::vector<float3> landmarks;
landmarks.reserve(subset_idxs.size() * 2);
for (int i = 0; i < subset_idxs.size(); i++) {
landmarks.push_back(Read3DLandmarkXYZ(input_data_0, subset_idxs[i][0]));
landmarks.push_back(Read3DLandmarkXYZ(input_data_0, subset_idxs[i][1]));
}
for (int i = 0; i < landmarks.size(); i++) {
landmarks[i].z = 1.0;
}
const float& r = rotation_radians;
// clang-format off
const Mat3 t_rotation = Mat3(std::cos(r), -std::sin(r), 0.0,
std::sin(r), std::cos(r), 0.0,
0.0, 0.0, 1.0);
const Mat3 t_rotation_inverse =
Mat3(std::cos(-r), -std::sin(-r), 0.0,
std::sin(-r), std::cos(-r), 0.0,
0.0, 0.0, 1.0);
// clang-format on
for (int i = 0; i < landmarks.size(); i++) {
landmarks[i] = t_rotation * landmarks[i];
}
float3 xy1_max = landmarks[0], xy1_min = landmarks[0];
for (int i = 1; i < landmarks.size(); i++) {
if (xy1_max.x < landmarks[i].x) xy1_max.x = landmarks[i].x;
if (xy1_max.y < landmarks[i].y) xy1_max.y = landmarks[i].y;
if (xy1_min.x > landmarks[i].x) xy1_min.x = landmarks[i].x;
if (xy1_min.y > landmarks[i].y) xy1_min.y = landmarks[i].y;
}
*crop_width = xy1_max.x - xy1_min.x;
*crop_height = xy1_max.y - xy1_min.y;
float3 crop_xy1 = xy1_min;
crop_xy1.x += xy1_max.x;
crop_xy1.y += xy1_max.y;
crop_xy1.x /= 2;
crop_xy1.y /= 2;
crop_xy1 = t_rotation_inverse * crop_xy1;
*crop_x = crop_xy1.x;
*crop_y = crop_xy1.y;
}
inline void LandmarksToTransformMatrixV2(
const LandmarksToTransformMatrixV2Attributes& params,
const RuntimeShape& input0_shape, const float* landmarks,
const RuntimeShape& output_shape, float* output_data) {
float rotation_radians = 0.0;
EstimateRotationRadians(landmarks, params.left_rotation_idx,
params.right_rotation_idx,
params.target_rotation_radians, &rotation_radians);
float crop_x = 0.0, crop_y = 0.0, crop_width = 0.0, crop_height = 0.0;
EstimateCenterAndSize(landmarks, params.subset_idxs, rotation_radians,
&crop_x, &crop_y, &crop_width, &crop_height);
// Turn off clang formatting to make matrices initialization more readable.
// clang-format off
Mat4 t = Mat4(1.0, 0.0, 0.0, 0.0,
0.0, 1.0, 0.0, 0.0,
0.0, 0.0, 1.0, 0.0,
0.0, 0.0, 0.0, 1.0);
const Mat4 t_shift = Mat4(1.0, 0.0, 0.0, crop_x,
0.0, 1.0, 0.0, crop_y,
0.0, 0.0, 1.0, 0.0,
0.0, 0.0, 0.0, 1.0);
t *= t_shift;
const float& r = -rotation_radians;
const Mat4 t_rotation = Mat4(std::cos(r), -std::sin(r), 0.0, 0.0,
std::sin(r), std::cos(r), 0.0, 0.0,
0.0, 0.0, 1.0, 0.0,
0.0, 0.0, 0.0, 1.0);
t *= t_rotation;
const float scale_x = params.scale_x * crop_width / params.output_width;
const float scale_y = params.scale_y * crop_height / params.output_height;
const Mat4 t_scale = Mat4(scale_x, 0.0, 0.0, 0.0,
0.0, scale_y, 0.0, 0.0,
0.0, 0.0, 1.0, 0.0,
0.0, 0.0, 0.0, 1.0);
t *= t_scale;
const float shift_x = -1.0 * (params.output_width / 2.0);
const float shift_y = -1.0 * (params.output_height / 2.0);
const Mat4 t_shift2 = Mat4(1.0, 0.0, 0.0, shift_x,
0.0, 1.0, 0.0, shift_y,
0.0, 0.0, 1.0, 0.0,
0.0, 0.0, 0.0, 1.0);
t *= t_shift2;
std::memcpy(output_data, t.data.data(), 16 * sizeof(float));
// clang-format on
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kDataInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 3);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
output_size->data[0] = kTransformMatrixShape.x;
output_size->data[1] = kTransformMatrixShape.y;
output_size->data[2] = kTransformMatrixShape.z;
return context->ResizeTensor(context, output, output_size);
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
LandmarksToTransformMatrixV2Attributes op_params;
BHWC output_shape;
auto status = tflite::gpu::ParseLandmarksToTransformMatrixV2Attributes(
node->custom_initial_data, node->custom_initial_data_size, &op_params,
&output_shape);
if (!status.ok()) {
context->ReportError(context, status.message().data());
return kTfLiteError;
}
if (op_params.left_rotation_idx < 0) {
context->ReportError(context, "Incorrect left_rotation_idx: %d",
op_params.left_rotation_idx);
return kTfLiteError;
}
if (op_params.right_rotation_idx < 0) {
context->ReportError(context, "Incorrect right_rotation_idx: %d",
op_params.right_rotation_idx);
return kTfLiteError;
}
if (op_params.output_height <= 0) {
context->ReportError(context, "Incorrect output_height: %d",
op_params.output_height);
return kTfLiteError;
}
if (op_params.output_width <= 0) {
context->ReportError(context, "Incorrect output_width: %d",
op_params.output_width);
return kTfLiteError;
}
if (op_params.scale_x <= 0) {
context->ReportError(context, "Incorrect scale_x: %d", op_params.scale_x);
return kTfLiteError;
}
if (op_params.scale_y <= 0) {
context->ReportError(context, "Incorrect scale_y: %d", op_params.scale_y);
return kTfLiteError;
}
int counter = 0;
for (auto& val : op_params.subset_idxs) {
for (int i = 0; i < 2; i++) {
if (val[i] < 0) {
context->ReportError(context,
"Incorrect subset value: index = %d, value = %d",
counter, val[i]);
return kTfLiteError;
}
counter++;
}
}
const TfLiteTensor* input0 = GetInput(context, node, kDataInputTensor);
TF_LITE_ENSURE(context, input0 != nullptr);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
LandmarksToTransformMatrixV2(
op_params, GetTensorShape(input0), GetTensorData<float>(input0),
GetTensorShape(output), GetTensorData<float>(output));
return kTfLiteOk;
}
} // namespace v2
} // namespace
TfLiteRegistration* RegisterLandmarksToTransformMatrixV1() {
static TfLiteRegistration reg = {
/*.init=*/nullptr,
/*.free=*/nullptr,
/*.prepare=*/v1::Prepare,
/*.invoke=*/v1::Eval,
/*.profiling_string=*/nullptr,
/*.builtin_code=*/tflite::BuiltinOperator_CUSTOM,
/*.custom_name=*/"Landmarks2TransformMatrix",
/*.version=*/1,
};
return &reg;
}
TfLiteRegistration* RegisterLandmarksToTransformMatrixV2() {
static TfLiteRegistration reg = {
/*.init=*/nullptr,
/*.free=*/nullptr,
/*.prepare=*/v2::Prepare,
/*.invoke=*/v2::Eval,
/*.profiling_string=*/nullptr,
/*.builtin_code=*/tflite::BuiltinOperator_CUSTOM,
/*.custom_name=*/"Landmarks2TransformMatrix",
/*.version=*/2,
};
return &reg;
}
} // namespace tflite_operations
} // namespace mediapipe

@@ -0,0 +1,30 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_UTIL_TFLITE_OPERATIONS_LANDMARKS_TO_TRANSFORM_MATRIX_H_
#define MEDIAPIPE_UTIL_TFLITE_OPERATIONS_LANDMARKS_TO_TRANSFORM_MATRIX_H_
#include "tensorflow/lite/kernels/kernel_util.h"
namespace mediapipe {
namespace tflite_operations {
TfLiteRegistration* RegisterLandmarksToTransformMatrixV1();
TfLiteRegistration* RegisterLandmarksToTransformMatrixV2();
} // namespace tflite_operations
} // namespace mediapipe
#endif // MEDIAPIPE_UTIL_TFLITE_OPERATIONS_LANDMARKS_TO_TRANSFORM_MATRIX_H_

@@ -0,0 +1,298 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/util/tflite/operations/transform_landmarks.h"
#include "tensorflow/lite/delegates/gpu/common/mediapipe/transform_landmarks.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/padding.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace mediapipe {
namespace tflite_operations {
namespace {
constexpr int kDataInput0Tensor = 0;
constexpr int kDataInput1Tensor = 1;
constexpr int kOutputTensor = 0;
float DotProduct(const tflite::gpu::float4& l, const tflite::gpu::float4& r) {
return l.x * r.x + l.y * r.y + l.z * r.z + l.w * r.w;
}
namespace v1 {
inline void TransformLandmarks(
const tflite::gpu::TransformLandmarksAttributes& params,
const tflite::RuntimeShape& input0_shape, const float* landmarks,
const tflite::RuntimeShape& input1_shape, const float* transform_matrix,
const tflite::RuntimeShape& output_shape, float* output_data) {
TFLITE_CHECK_EQ(input0_shape.DimensionsCount(), 4);
TFLITE_CHECK_EQ(output_shape.DimensionsCount(), 4);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
const int output_channels = output_shape.Dims(3);
TFLITE_CHECK_EQ(input0_shape.Dims(3) % params.dimensions, 0);
TFLITE_CHECK_NE(params.scale, 0);
tflite::RuntimeShape input_shape_with_batch{/*batch=*/1, input0_shape.Dims(1),
input0_shape.Dims(2),
input0_shape.Dims(3)};
tflite::RuntimeShape output_shape_with_batch{
/*batch=*/1, output_shape.Dims(1), output_shape.Dims(2),
output_shape.Dims(3)};
// Read first two rows of transformation matrix
tflite::gpu::float4 x_transform(transform_matrix[0], transform_matrix[1],
transform_matrix[2],
transform_matrix[3] * params.scale);
tflite::gpu::float4 y_transform(transform_matrix[4], transform_matrix[5],
transform_matrix[6],
transform_matrix[7] * params.scale);
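// Only these two rows are needed: each landmark is transformed as a
// homogeneous point (x, y, 0, 1), and z (when present) is copied through.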
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
for (int landmark = 0; landmark < output_channels / params.dimensions;
++landmark) {
const int offset = Offset(output_shape_with_batch, 0, out_y, out_x,
landmark * params.dimensions);
if (params.dimensions == 2) {
tflite::gpu::float4 lv(landmarks[offset], landmarks[offset + 1],
static_cast<float>(0.0),
static_cast<float>(1.0));
tflite::gpu::float2 transformed(DotProduct(x_transform, lv),
DotProduct(y_transform, lv));
output_data[offset] = transformed.x;
output_data[offset + 1] = transformed.y;
}
if (params.dimensions == 3) {
tflite::gpu::float4 lv(landmarks[offset], landmarks[offset + 1],
static_cast<float>(0.0),
static_cast<float>(1.0));
tflite::gpu::float3 transformed(DotProduct(x_transform, lv),
DotProduct(y_transform, lv), lv.z);
output_data[offset] = transformed.x;
output_data[offset + 1] = transformed.y;
output_data[offset + 2] = landmarks[offset + 2];
}
}
}
}
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, tflite::NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, tflite::NumOutputs(node), 1);
const TfLiteTensor* input =
tflite::GetInput(context, node, kDataInput0Tensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output = tflite::GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(input), 4);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
output_size->data[0] = input->dims->data[0];
output_size->data[1] = input->dims->data[1];
output_size->data[2] = input->dims->data[2];
output_size->data[3] = input->dims->data[3];
return context->ResizeTensor(context, output, output_size);
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::gpu::TransformLandmarksAttributes op_params;
tflite::gpu::BHWC output_shape;
auto status = tflite::gpu::ParseTransformLandmarksV1Attributes(
node->custom_initial_data, node->custom_initial_data_size, &op_params,
&output_shape);
if (!status.ok()) {
context->ReportError(context, status.message().data());
return kTfLiteError;
}
if (op_params.dimensions != 3 && op_params.dimensions != 2) {
context->ReportError(context, "Incorrect dimensions size: %d",
op_params.dimensions);
return kTfLiteError;
}
if (op_params.scale == 0) {
context->ReportError(context, "Incorrect scale value: %d", op_params.scale);
return kTfLiteError;
}
const TfLiteTensor* input0 =
tflite::GetInput(context, node, kDataInput0Tensor);
TF_LITE_ENSURE(context, input0 != nullptr);
const TfLiteTensor* input1 =
tflite::GetInput(context, node, kDataInput1Tensor);
TF_LITE_ENSURE(context, input1 != nullptr);
TfLiteTensor* output = tflite::GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TransformLandmarks(
op_params, tflite::GetTensorShape(input0),
tflite::GetTensorData<float>(input0), tflite::GetTensorShape(input1),
tflite::GetTensorData<float>(input1), tflite::GetTensorShape(output),
tflite::GetTensorData<float>(output));
return kTfLiteOk;
}
} // namespace v1
namespace v2 {
inline void TransformLandmarksV2(
const tflite::gpu::TransformLandmarksAttributes& params,
const tflite::RuntimeShape& input0_shape, const float* landmarks,
const float* transform_matrix, // transformation matrix
const tflite::RuntimeShape& output_shape, float* output_data) {
TFLITE_CHECK_EQ(input0_shape.DimensionsCount(), 3);
TFLITE_CHECK_EQ(output_shape.DimensionsCount(), 3);
const int output_width = output_shape.Dims(1);
TFLITE_CHECK_EQ(input0_shape.Dims(2) % params.dimensions, 0);
tflite::RuntimeShape input_shape_with_batch{/*batch=*/1, input0_shape.Dims(0),
input0_shape.Dims(1),
input0_shape.Dims(2)};
tflite::RuntimeShape output_shape_with_batch{
/*batch=*/1, output_shape.Dims(0), output_shape.Dims(1),
output_shape.Dims(2)};
// Read first two rows of transformation matrix
tflite::gpu::float4 x_transform(transform_matrix[0], transform_matrix[1],
transform_matrix[2], transform_matrix[3]);
tflite::gpu::float4 y_transform(transform_matrix[4], transform_matrix[5],
transform_matrix[6], transform_matrix[7]);
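// As in V1, only the first two matrix rows take part in the transform.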
for (int landmark = 0; landmark < output_width; ++landmark) {
const int offset = Offset(input_shape_with_batch, 0, 0, landmark, 0);
if (params.dimensions == 2) {
tflite::gpu::float4 lv(landmarks[offset], landmarks[offset + 1],
static_cast<float>(0.0), static_cast<float>(1.0));
tflite::gpu::float2 transformed(DotProduct(x_transform, lv),
DotProduct(y_transform, lv));
output_data[offset] = transformed.x;
output_data[offset + 1] = transformed.y;
}
if (params.dimensions == 3) {
tflite::gpu::float4 lv(landmarks[offset], landmarks[offset + 1],
static_cast<float>(0.0), static_cast<float>(1.0));
tflite::gpu::float3 transformed(DotProduct(x_transform, lv),
DotProduct(y_transform, lv), lv.z);
output_data[offset] = transformed.x;
output_data[offset + 1] = transformed.y;
output_data[offset + 2] = landmarks[offset + 2];
}
}
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, tflite::NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, tflite::NumOutputs(node), 1);
const TfLiteTensor* input =
tflite::GetInput(context, node, kDataInput0Tensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output = tflite::GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(input), 3);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
output_size->data[0] = input->dims->data[0];
output_size->data[1] = input->dims->data[1];
output_size->data[2] = input->dims->data[2];
return context->ResizeTensor(context, output, output_size);
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::gpu::TransformLandmarksAttributes op_params;
TfLiteTensor* output = tflite::GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
tflite::RuntimeShape runtime_output_shape = tflite::GetTensorShape(output);
tflite::gpu::BHWC output_shape(1, runtime_output_shape.Dims(0),
runtime_output_shape.Dims(1),
runtime_output_shape.Dims(2));
auto status = tflite::gpu::ParseTransformLandmarksV2Attributes(
node->custom_initial_data, node->custom_initial_data_size, &op_params,
&output_shape);
if (!status.ok()) {
context->ReportError(context, status.message().data());
return kTfLiteError;
}
if (op_params.dimensions != 3 && op_params.dimensions != 2) {
context->ReportError(context, "Incorrect dimensions size: %d",
op_params.dimensions);
return kTfLiteError;
}
const TfLiteTensor* input0 =
tflite::GetInput(context, node, kDataInput0Tensor);
TF_LITE_ENSURE(context, input0 != nullptr);
const TfLiteTensor* input1 =
tflite::GetInput(context, node, kDataInput1Tensor);
TF_LITE_ENSURE(context, input1 != nullptr);
TransformLandmarksV2(op_params, tflite::GetTensorShape(input0),
tflite::GetTensorData<float>(input0),
tflite::GetTensorData<float>(input1),
tflite::GetTensorShape(output),
tflite::GetTensorData<float>(output));
return kTfLiteOk;
}
} // namespace v2
} // namespace
TfLiteRegistration* RegisterTransformLandmarksV1() {
static TfLiteRegistration reg = {
/*.init=*/nullptr,
/*.free=*/nullptr,
/*.prepare=*/v1::Prepare,
/*.invoke=*/v1::Eval,
/*.profiling_string=*/nullptr,
/*.builtin_code=*/tflite::BuiltinOperator_CUSTOM,
/*.custom_name=*/"TransformLandmarks",
/*.version=*/1,
};
return &reg;
}
TfLiteRegistration* RegisterTransformLandmarksV2() {
static TfLiteRegistration reg = {
/*.init=*/nullptr,
/*.free=*/nullptr,
/*.prepare=*/v2::Prepare,
/*.invoke=*/v2::Eval,
/*.profiling_string=*/nullptr,
/*.builtin_code=*/tflite::BuiltinOperator_CUSTOM,
/*.custom_name=*/"TransformLandmarks",
/*.version=*/2,
};
return &reg;
}
} // namespace tflite_operations
} // namespace mediapipe
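
For intuition, here is a minimal standalone sketch (an illustrative helper, not part of the kernel) of the math TransformLandmarksV2 applies when params.dimensions == 2: a row-major 4x4 matrix whose first two rows act on homogeneous points (x, y, 0, 1).

#include <array>
#include <cstdio>
#include <vector>

// Applies the first two rows of a row-major 4x4 matrix to each (x, y) pair,
// mirroring the kernel's DotProduct-based update for 2-D landmarks.
void ApplyTransformTo2DLandmarks(const std::array<float, 16>& m,
                                 std::vector<float>* xy) {
  for (size_t i = 0; i + 1 < xy->size(); i += 2) {
    const float x = (*xy)[i];
    const float y = (*xy)[i + 1];
    (*xy)[i] = m[0] * x + m[1] * y + m[3];      // row 0 . (x, y, 0, 1)
    (*xy)[i + 1] = m[4] * x + m[5] * y + m[7];  // row 1 . (x, y, 0, 1)
  }
}

int main() {
  // A 90-degree rotation combined with a (10, 0) translation.
  const std::array<float, 16> m = {0, -1, 0, 10,
                                   1,  0, 0,  0,
                                   0,  0, 1,  0,
                                   0,  0, 0,  1};
  std::vector<float> xy = {1, 0, 0, 2};
  ApplyTransformTo2DLandmarks(m, &xy);
  std::printf("(%g, %g) (%g, %g)\n", xy[0], xy[1], xy[2], xy[3]);
  // Prints: (10, 1) (8, 0)
  return 0;
}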

@@ -0,0 +1,30 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_UTIL_TFLITE_OPERATIONS_TRANSFORM_LANDMARKS_H_
#define MEDIAPIPE_UTIL_TFLITE_OPERATIONS_TRANSFORM_LANDMARKS_H_
#include "tensorflow/lite/kernels/kernel_util.h"
namespace mediapipe {
namespace tflite_operations {
TfLiteRegistration* RegisterTransformLandmarksV1();
TfLiteRegistration* RegisterTransformLandmarksV2();
} // namespace tflite_operations
} // namespace mediapipe
#endif // MEDIAPIPE_UTIL_TFLITE_OPERATIONS_TRANSFORM_LANDMARKS_H_

@@ -0,0 +1,332 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/util/tflite/operations/transform_tensor_bilinear.h"
#include "tensorflow/lite/delegates/gpu/common/mediapipe/transform_tensor_bilinear.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/padding.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace mediapipe {
namespace tflite_operations {
namespace {
constexpr int kDataInput0Tensor = 0;
constexpr int kDataInput1Tensor = 1;
constexpr int kOutputTensor = 0;
float DotProduct(const tflite::gpu::float4& l, const tflite::gpu::float4& r) {
return l.x * r.x + l.y * r.y + l.z * r.z + l.w * r.w;
}
namespace v1 {
inline void TransformTensor(
const tflite::gpu::TransformTensorBilinearAttributes& params,
const tflite::RuntimeShape& input0_shape,
const float* input_data_0, // data
const tflite::RuntimeShape& input1_shape,
const float* input_data_1, // transformation matrix
const tflite::RuntimeShape& output_shape, float* output_data) {
TFLITE_CHECK_EQ(input0_shape.DimensionsCount(), 4);
TFLITE_CHECK_EQ(output_shape.DimensionsCount(), 4);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
const int output_channels = output_shape.Dims(3);
const int input_height = input0_shape.Dims(1);
const int input_width = input0_shape.Dims(2);
const int input_channels = input0_shape.Dims(3);
tflite::RuntimeShape input_shape_with_batch{/*batch=*/1, input_height,
input_width, input_channels};
tflite::RuntimeShape output_shape_with_batch{/*batch=*/1, output_height,
output_width, output_channels};
// Read first two rows of transformation matrix
tflite::gpu::float4 x_transform(input_data_1[0], input_data_1[1],
input_data_1[2], input_data_1[3]);
tflite::gpu::float4 y_transform(input_data_1[4], input_data_1[5],
input_data_1[6], input_data_1[7]);
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
tflite::gpu::float4 coord(
static_cast<float>(out_x), static_cast<float>(out_y),
static_cast<float>(0.0), static_cast<float>(1.0));
// Transformed coordinates.
tflite::gpu::float2 tc(DotProduct(x_transform, coord),
DotProduct(y_transform, coord));
bool out_of_bound = tc.x < 0.0 || tc.x > input_width - 1 || tc.y < 0.0 ||
tc.y > input_height - 1;
for (int out_z = 0; out_z < output_channels; ++out_z) {
float result = 0;
if (!out_of_bound) {
// Corners position:
// q_11 --- q_21
// ---- ----
// q_12 --- q_22
auto ReadValue = [&](int h, int w) -> float {
return h < 0 || w < 0 || h >= input_height || w >= input_width
? 0
: input_data_0[Offset(input_shape_with_batch, 0, h, w,
out_z)];
};
float q_11 = ReadValue(floor(tc.y), floor(tc.x));
float q_21 = ReadValue(floor(tc.y), floor(tc.x) + 1);
float q_12 = ReadValue(floor(tc.y) + 1, floor(tc.x));
float q_22 = ReadValue(floor(tc.y) + 1, floor(tc.x) + 1);
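// The fractional parts of tc give the blend weights: interpolate along x
// between each corner pair, then along y between the two results.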
float right_contrib = tc.x - floor(tc.x);
float lower_contrib = tc.y - floor(tc.y);
float upper = (1.0 - right_contrib) * q_11 + right_contrib * q_21;
float lower = (1.0 - right_contrib) * q_12 + right_contrib * q_22;
result = lower_contrib * lower + (1.0 - lower_contrib) * upper;
}
const int out_offset =
Offset(output_shape_with_batch, 0, out_y, out_x, out_z);
output_data[out_offset] = result;
}
}
}
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, tflite::NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, tflite::NumOutputs(node), 1);
const TfLiteTensor* input =
tflite::GetInput(context, node, kDataInput0Tensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output = tflite::GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(input), 4);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::gpu::TransformTensorBilinearAttributes op_params;
tflite::gpu::BHWC output_shape;
auto status = tflite::gpu::ParseTransformTensorBilinearV1Attributes(
node->custom_initial_data, node->custom_initial_data_size, &op_params,
&output_shape);
if (!status.ok()) {
context->ReportError(context, status.message().data());
return kTfLiteError;
}
const TfLiteTensor* input0 =
tflite::GetInput(context, node, kDataInput0Tensor);
TF_LITE_ENSURE(context, input0 != nullptr);
const TfLiteTensor* input1 =
tflite::GetInput(context, node, kDataInput1Tensor);
TF_LITE_ENSURE(context, input1 != nullptr);
TfLiteTensor* output = tflite::GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TransformTensor(
op_params, tflite::GetTensorShape(input0),
tflite::GetTensorData<float>(input0), tflite::GetTensorShape(input1),
tflite::GetTensorData<float>(input1), tflite::GetTensorShape(output),
tflite::GetTensorData<float>(output));
return kTfLiteOk;
}
} // namespace v1
namespace v2 {
inline void TransformTensorBilinearV2(
const tflite::gpu::TransformTensorBilinearAttributes& params,
const tflite::RuntimeShape& input0_shape,
const float* input_data_0, // data
const tflite::RuntimeShape& input1_shape,
const float* input_data_1, // transformation matrix
const tflite::RuntimeShape& output_shape, float* output_data) {
TFLITE_CHECK_EQ(input0_shape.DimensionsCount(), 4);
TFLITE_CHECK_EQ(output_shape.DimensionsCount(), 4);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
const int output_channels = output_shape.Dims(3);
const int input_height = input0_shape.Dims(1);
const int input_width = input0_shape.Dims(2);
const int input_channels = input0_shape.Dims(3);
tflite::RuntimeShape input_shape_with_batch{/*batch=*/1, input_height,
input_width, input_channels};
tflite::RuntimeShape output_shape_with_batch{/*batch=*/1, output_height,
output_width, output_channels};
// Read first two rows of transformation matrix
tflite::gpu::float4 x_transform(input_data_1[0], input_data_1[1],
input_data_1[2], input_data_1[3]);
tflite::gpu::float4 y_transform(input_data_1[4], input_data_1[5],
input_data_1[6], input_data_1[7]);
// Align-corners correction: T -> S * (T * A), where T is the transformation
// matrix and the subtraction (S) and addition (A) matrices are:
//        S                  A
//   1 0 0 -0.5         1 0 0 0.5
//   0 1 0 -0.5         0 1 0 0.5
//   0 0 1  0           0 0 1 0
//   0 0 0  1           0 0 0 1
// Column 3 and rows 3, 4 of T are identity, so the whole correction folds
// into the translation component of the first two rows, as a manual
// multiplication shows.
x_transform[3] += x_transform[0] * 0.5 + x_transform[1] * 0.5 - 0.5;
y_transform[3] += y_transform[0] * 0.5 + y_transform[1] * 0.5 - 0.5;
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
tflite::gpu::float4 coord(
static_cast<float>(out_x), static_cast<float>(out_y),
static_cast<float>(0.0), static_cast<float>(1.0));
// Transformed coordinates.
tflite::gpu::float2 tc(DotProduct(x_transform, coord),
DotProduct(y_transform, coord));
bool out_of_bound = tc.x < 0.0 || tc.x > input_width - 1 || tc.y < 0.0 ||
tc.y > input_height - 1;
for (int out_z = 0; out_z < output_channels; ++out_z) {
float result = 0;
if (!out_of_bound) {
// Corners position:
// q_11 --- q_21
// ---- ----
// q_12 --- q_22
auto ReadValue = [&](int h, int w) -> float {
return h < 0 || w < 0 || h >= input_height || w >= input_width
? 0
: input_data_0[Offset(input_shape_with_batch, 0, h, w,
out_z)];
};
float q_11 = ReadValue(floor(tc.y), floor(tc.x));
float q_21 = ReadValue(floor(tc.y), floor(tc.x) + 1);
float q_12 = ReadValue(floor(tc.y) + 1, floor(tc.x));
float q_22 = ReadValue(floor(tc.y) + 1, floor(tc.x) + 1);
float right_contrib = tc.x - floor(tc.x);
float lower_contrib = tc.y - floor(tc.y);
float upper = (1.0 - right_contrib) * q_11 + right_contrib * q_21;
float lower = (1.0 - right_contrib) * q_12 + right_contrib * q_22;
result = lower_contrib * lower + (1.0 - lower_contrib) * upper;
}
const int out_offset =
Offset(output_shape_with_batch, 0, out_y, out_x, out_z);
output_data[out_offset] = result;
}
}
}
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, tflite::NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, tflite::NumOutputs(node), 1);
const TfLiteTensor* input =
tflite::GetInput(context, node, kDataInput0Tensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output = tflite::GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(input), 4);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::gpu::TransformTensorBilinearAttributes op_params;
tflite::gpu::BHWC output_shape;
auto status = tflite::gpu::ParseTransformTensorBilinearV2Attributes(
node->custom_initial_data, node->custom_initial_data_size, &op_params,
&output_shape);
if (!status.ok()) {
context->ReportError(context, status.message().data());
return kTfLiteError;
}
const TfLiteTensor* input0 =
tflite::GetInput(context, node, kDataInput0Tensor);
TF_LITE_ENSURE(context, input0 != nullptr);
const TfLiteTensor* input1 =
tflite::GetInput(context, node, kDataInput1Tensor);
TF_LITE_ENSURE(context, input1 != nullptr);
TfLiteTensor* output = tflite::GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TransformTensorBilinearV2(
op_params, tflite::GetTensorShape(input0),
tflite::GetTensorData<float>(input0), tflite::GetTensorShape(input1),
tflite::GetTensorData<float>(input1), tflite::GetTensorShape(output),
tflite::GetTensorData<float>(output));
return kTfLiteOk;
}
} // namespace v2
} // namespace
TfLiteRegistration* RegisterTransformTensorBilinearV1() {
static TfLiteRegistration reg = {
/*.init=*/nullptr,
/*.free=*/nullptr,
/*.prepare=*/v1::Prepare,
/*.invoke=*/v1::Eval,
/*.profiling_string=*/nullptr,
/*.builtin_code=*/tflite::BuiltinOperator_CUSTOM,
/*.custom_name=*/"TransformTensor",
/*.version=*/1,
};
return &reg;
}
TfLiteRegistration* RegisterTransformTensorBilinearV2() {
static TfLiteRegistration reg = {
/*.init=*/nullptr,
/*.free=*/nullptr,
/*.prepare=*/v2::Prepare,
/*.invoke=*/v2::Eval,
/*.profiling_string=*/nullptr,
/*.builtin_code=*/tflite::BuiltinOperator_CUSTOM,
/*.custom_name=*/"TransformTensorBilinear",
/*.version=*/2,
};
return &reg;
}
} // namespace tflite_operations
} // namespace mediapipe
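
As a sanity check on the corner scheme above, a small self-contained sketch (SampleBilinear is a hypothetical name, not part of this file) that samples a row-major single-channel image at a fractional coordinate with zero padding, matching the kernels' q_11/q_21/q_12/q_22 layout.

#include <cmath>
#include <cstdio>

// Bilinear sample of a row-major h x w single-channel image at (x, y):
// zero outside the image, x-then-y interpolation as in the kernels above.
float SampleBilinear(const float* img, int h, int w, float x, float y) {
  const int x0 = static_cast<int>(std::floor(x));
  const int y0 = static_cast<int>(std::floor(y));
  auto at = [&](int r, int c) -> float {
    return (r < 0 || c < 0 || r >= h || c >= w) ? 0.0f : img[r * w + c];
  };
  const float fx = x - static_cast<float>(x0);  // right_contrib
  const float fy = y - static_cast<float>(y0);  // lower_contrib
  const float upper = (1.0f - fx) * at(y0, x0) + fx * at(y0, x0 + 1);
  const float lower = (1.0f - fx) * at(y0 + 1, x0) + fx * at(y0 + 1, x0 + 1);
  return (1.0f - fy) * upper + fy * lower;
}

int main() {
  const float img[4] = {0.0f, 1.0f,   // row 0
                        2.0f, 3.0f};  // row 1
  // The exact center blends all four pixels equally: (0+1+2+3)/4 = 1.5.
  std::printf("%g\n", SampleBilinear(img, 2, 2, 0.5f, 0.5f));
  return 0;
}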

@@ -0,0 +1,30 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_UTIL_TFLITE_OPERATIONS_TRANSFORM_TENSOR_BILINEAR_H_
#define MEDIAPIPE_UTIL_TFLITE_OPERATIONS_TRANSFORM_TENSOR_BILINEAR_H_
#include "tensorflow/lite/kernels/kernel_util.h"
namespace mediapipe {
namespace tflite_operations {
TfLiteRegistration* RegisterTransformTensorBilinearV1();
TfLiteRegistration* RegisterTransformTensorBilinearV2();
} // namespace tflite_operations
} // namespace mediapipe
#endif // MEDIAPIPE_UTIL_TFLITE_OPERATIONS_TRANSFORM_TENSOR_BILINEAR_H_
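
For completeness, a minimal sketch of wiring these kernels into a TFLite interpreter through tflite::MutableOpResolver. The helper name below is an assumption; the custom-op names and versions mirror the TfLiteRegistration entries in the corresponding .cc files.

#include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h"
#include "mediapipe/util/tflite/operations/transform_landmarks.h"
#include "mediapipe/util/tflite/operations/transform_tensor_bilinear.h"
#include "tensorflow/lite/mutable_op_resolver.h"

// Hypothetical helper: registers the V2 variants of the custom ops so a
// model referencing them can be resolved at interpreter-build time.
void AddMediaPipeCustomOps(tflite::MutableOpResolver* resolver) {
  resolver->AddCustom(
      "Landmarks2TransformMatrix",
      mediapipe::tflite_operations::RegisterLandmarksToTransformMatrixV2(),
      /*min_version=*/2, /*max_version=*/2);
  resolver->AddCustom(
      "TransformLandmarks",
      mediapipe::tflite_operations::RegisterTransformLandmarksV2(),
      /*min_version=*/2, /*max_version=*/2);
  resolver->AddCustom(
      "TransformTensorBilinear",
      mediapipe::tflite_operations::RegisterTransformTensorBilinearV2(),
      /*min_version=*/2, /*max_version=*/2);
}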