Project import generated by Copybara.
GitOrigin-RevId: 283c1a295de0a53e47d7a94996bda0c52dcfd677
This commit is contained in:
parent
6abec128ed
commit
137e1cc763
Binary file not shown.
Before Width: | Height: | Size: 56 KiB After Width: | Height: | Size: 77 KiB |
|
@ -88,11 +88,11 @@ from [COCO topology](https://cocodataset.org/#keypoints-2020).
|
||||||
|
|
||||||
Method | Yoga <br/> [`mAP`] | Yoga <br/> [`PCK@0.2`] | Dance <br/> [`mAP`] | Dance <br/> [`PCK@0.2`] | HIIT <br/> [`mAP`] | HIIT <br/> [`PCK@0.2`]
|
Method | Yoga <br/> [`mAP`] | Yoga <br/> [`PCK@0.2`] | Dance <br/> [`mAP`] | Dance <br/> [`PCK@0.2`] | HIIT <br/> [`mAP`] | HIIT <br/> [`PCK@0.2`]
|
||||||
----------------------------------------------------------------------------------------------------- | -----------------: | ---------------------: | ------------------: | ----------------------: | -----------------: | ---------------------:
|
----------------------------------------------------------------------------------------------------- | -----------------: | ---------------------: | ------------------: | ----------------------: | -----------------: | ---------------------:
|
||||||
BlazePose.Heavy | 68.1 | **96.4** | 73.0 | **97.2** | 74.0 | **97.5**
|
BlazePose GHUM Heavy | 68.1 | **96.4** | 73.0 | **97.2** | 74.0 | **97.5**
|
||||||
BlazePose.Full | 62.6 | **95.5** | 67.4 | **96.3** | 68.0 | **95.7**
|
BlazePose GHUM Full | 62.6 | **95.5** | 67.4 | **96.3** | 68.0 | **95.7**
|
||||||
BlazePose.Lite | 45.0 | **90.2** | 53.6 | **92.5** | 53.8 | **93.5**
|
BlazePose GHUM Lite | 45.0 | **90.2** | 53.6 | **92.5** | 53.8 | **93.5**
|
||||||
[AlphaPose.ResNet50](https://github.com/MVIG-SJTU/AlphaPose) | 63.4 | **96.0** | 57.8 | **95.5** | 63.4 | **96.0**
|
[AlphaPose ResNet50](https://github.com/MVIG-SJTU/AlphaPose) | 63.4 | **96.0** | 57.8 | **95.5** | 63.4 | **96.0**
|
||||||
[Apple.Vision](https://developer.apple.com/documentation/vision/detecting_human_body_poses_in_images) | 32.8 | **82.7** | 36.4 | **91.4** | 44.5 | **88.6**
|
[Apple Vision](https://developer.apple.com/documentation/vision/detecting_human_body_poses_in_images) | 32.8 | **82.7** | 36.4 | **91.4** | 44.5 | **88.6**
|
||||||
|
|
||||||
![pose_tracking_pck_chart.png](../images/mobile/pose_tracking_pck_chart.png) |
|
![pose_tracking_pck_chart.png](../images/mobile/pose_tracking_pck_chart.png) |
|
||||||
:--------------------------------------------------------------------------: |
|
:--------------------------------------------------------------------------: |
|
||||||
|
@ -102,10 +102,10 @@ We designed our models specifically for live perception use cases, so all of
|
||||||
them work in real-time on the majority of modern devices.
|
them work in real-time on the majority of modern devices.
|
||||||
|
|
||||||
Method | Latency <br/> Pixel 3 [TFLite GPU](https://www.tensorflow.org/lite/performance/gpu_advanced) | Latency <br/> MacBook Pro (15-inch 2017)
|
Method | Latency <br/> Pixel 3 [TFLite GPU](https://www.tensorflow.org/lite/performance/gpu_advanced) | Latency <br/> MacBook Pro (15-inch 2017)
|
||||||
--------------- | -------------------------------------------------------------------------------------------: | ---------------------------------------:
|
-------------------- | -------------------------------------------------------------------------------------------: | ---------------------------------------:
|
||||||
BlazePose.Heavy | 53 ms | 38 ms
|
BlazePose GHUM Heavy | 53 ms | 38 ms
|
||||||
BlazePose.Full | 25 ms | 27 ms
|
BlazePose GHUM Full | 25 ms | 27 ms
|
||||||
BlazePose.Lite | 20 ms | 25 ms
|
BlazePose GHUM Lite | 20 ms | 25 ms
|
||||||
|
|
||||||
## Models
|
## Models
|
||||||
|
|
||||||
|
@ -237,7 +237,7 @@ pixel respectively. Please refer to the platform-specific usage examples below
|
||||||
for usage details.
|
for usage details.
|
||||||
|
|
||||||
*Fig 6. Example of MediaPipe Pose segmentation mask.* |
|
*Fig 6. Example of MediaPipe Pose segmentation mask.* |
|
||||||
:-----------------------------------------------------------: |
|
:---------------------------------------------------: |
|
||||||
<video autoplay muted loop preload style="height: auto; width: 480px"><source src="../images/mobile/pose_segmentation.mp4" type="video/mp4"></video> |
|
<video autoplay muted loop preload style="height: auto; width: 480px"><source src="../images/mobile/pose_segmentation.mp4" type="video/mp4"></video> |
|
||||||
|
|
||||||
### Python Solution API
|
### Python Solution API
|
||||||
|
|
|
@ -217,6 +217,7 @@ absl::Status PacketThinnerCalculator::Open(CalculatorContext* cc) {
|
||||||
header->format = video_header.format;
|
header->format = video_header.format;
|
||||||
header->width = video_header.width;
|
header->width = video_header.width;
|
||||||
header->height = video_header.height;
|
header->height = video_header.height;
|
||||||
|
header->duration = video_header.duration;
|
||||||
header->frame_rate = new_frame_rate;
|
header->frame_rate = new_frame_rate;
|
||||||
cc->Outputs().Index(0).SetHeader(Adopt(header.release()));
|
cc->Outputs().Index(0).SetHeader(Adopt(header.release()));
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -356,6 +356,57 @@ cc_library(
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mediapipe_proto_library(
|
||||||
|
name = "landmarks_to_tensor_calculator_proto",
|
||||||
|
srcs = ["landmarks_to_tensor_calculator.proto"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
|
"//mediapipe/framework:calculator_proto",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "landmarks_to_tensor_calculator",
|
||||||
|
srcs = ["landmarks_to_tensor_calculator.cc"],
|
||||||
|
hdrs = ["landmarks_to_tensor_calculator.h"],
|
||||||
|
copts = select({
|
||||||
|
"//mediapipe:apple": [
|
||||||
|
"-x objective-c++",
|
||||||
|
"-fobjc-arc", # enable reference-counting
|
||||||
|
],
|
||||||
|
"//conditions:default": [],
|
||||||
|
}),
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":landmarks_to_tensor_calculator_cc_proto",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework/api2:node",
|
||||||
|
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||||
|
"//mediapipe/framework/formats:tensor",
|
||||||
|
"//mediapipe/framework/port:ret_check",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "landmarks_to_tensor_calculator_test",
|
||||||
|
srcs = ["landmarks_to_tensor_calculator_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":landmarks_to_tensor_calculator",
|
||||||
|
":landmarks_to_tensor_calculator_cc_proto",
|
||||||
|
"//mediapipe/framework:calculator_cc_proto",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:calculator_runner",
|
||||||
|
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||||
|
"//mediapipe/framework/formats:tensor",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
|
"@com_google_absl//absl/memory",
|
||||||
|
"@com_google_googletest//:gtest_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
mediapipe_proto_library(
|
mediapipe_proto_library(
|
||||||
name = "tensors_to_floats_calculator_proto",
|
name = "tensors_to_floats_calculator_proto",
|
||||||
srcs = ["tensors_to_floats_calculator.proto"],
|
srcs = ["tensors_to_floats_calculator.proto"],
|
||||||
|
|
|
@ -99,13 +99,11 @@ class InferenceCalculator : public NodeIntf {
|
||||||
kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"};
|
kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"};
|
||||||
static constexpr SideInput<TfLiteModelPtr>::Optional kSideInModel{"MODEL"};
|
static constexpr SideInput<TfLiteModelPtr>::Optional kSideInModel{"MODEL"};
|
||||||
static constexpr Output<std::vector<Tensor>> kOutTensors{"TENSORS"};
|
static constexpr Output<std::vector<Tensor>> kOutTensors{"TENSORS"};
|
||||||
static constexpr SideInput<std::string>::Optional kNnApiDelegateCacheDir{
|
static constexpr SideInput<
|
||||||
"NNAPI_CACHE_DIR"};
|
mediapipe::InferenceCalculatorOptions::Delegate>::Optional kDelegate{
|
||||||
static constexpr SideInput<std::string>::Optional kNnApiDelegateModelToken{
|
"DELEGATE"};
|
||||||
"NNAPI_MODEL_TOKEN"};
|
|
||||||
MEDIAPIPE_NODE_CONTRACT(kInTensors, kSideInCustomOpResolver, kSideInModel,
|
MEDIAPIPE_NODE_CONTRACT(kInTensors, kSideInCustomOpResolver, kSideInModel,
|
||||||
kOutTensors, kNnApiDelegateCacheDir,
|
kOutTensors, kDelegate);
|
||||||
kNnApiDelegateModelToken);
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
using TfLiteDelegatePtr =
|
using TfLiteDelegatePtr =
|
||||||
|
|
|
@ -18,6 +18,9 @@ package mediapipe;
|
||||||
|
|
||||||
import "mediapipe/framework/calculator.proto";
|
import "mediapipe/framework/calculator.proto";
|
||||||
|
|
||||||
|
option java_package = "com.google.mediapipe.calculator.proto";
|
||||||
|
option java_outer_classname = "InferenceCalculatorProto";
|
||||||
|
|
||||||
// Full Example:
|
// Full Example:
|
||||||
//
|
//
|
||||||
// node {
|
// node {
|
||||||
|
|
|
@ -50,11 +50,13 @@ int GetXnnpackDefaultNumThreads() {
|
||||||
// Returns number of threads to configure XNNPACK delegate with.
|
// Returns number of threads to configure XNNPACK delegate with.
|
||||||
// Returns user provided value if specified. Otherwise, tries to choose optimal
|
// Returns user provided value if specified. Otherwise, tries to choose optimal
|
||||||
// number of threads depending on the device.
|
// number of threads depending on the device.
|
||||||
int GetXnnpackNumThreads(const mediapipe::InferenceCalculatorOptions& opts) {
|
int GetXnnpackNumThreads(
|
||||||
|
const bool opts_has_delegate,
|
||||||
|
const mediapipe::InferenceCalculatorOptions::Delegate& opts_delegate) {
|
||||||
static constexpr int kDefaultNumThreads = -1;
|
static constexpr int kDefaultNumThreads = -1;
|
||||||
if (opts.has_delegate() && opts.delegate().has_xnnpack() &&
|
if (opts_has_delegate && opts_delegate.has_xnnpack() &&
|
||||||
opts.delegate().xnnpack().num_threads() != kDefaultNumThreads) {
|
opts_delegate.xnnpack().num_threads() != kDefaultNumThreads) {
|
||||||
return opts.delegate().xnnpack().num_threads();
|
return opts_delegate.xnnpack().num_threads();
|
||||||
}
|
}
|
||||||
return GetXnnpackDefaultNumThreads();
|
return GetXnnpackDefaultNumThreads();
|
||||||
}
|
}
|
||||||
|
@ -175,33 +177,40 @@ absl::Status InferenceCalculatorCpuImpl::LoadDelegateAndAllocateTensors(
|
||||||
absl::Status InferenceCalculatorCpuImpl::LoadDelegate(CalculatorContext* cc) {
|
absl::Status InferenceCalculatorCpuImpl::LoadDelegate(CalculatorContext* cc) {
|
||||||
const auto& calculator_opts =
|
const auto& calculator_opts =
|
||||||
cc->Options<mediapipe::InferenceCalculatorOptions>();
|
cc->Options<mediapipe::InferenceCalculatorOptions>();
|
||||||
if (calculator_opts.has_delegate() &&
|
auto opts_delegate = calculator_opts.delegate();
|
||||||
calculator_opts.delegate().has_tflite()) {
|
if (!kDelegate(cc).IsEmpty()) {
|
||||||
|
mediapipe::InferenceCalculatorOptions::Delegate input_side_packet_delegate =
|
||||||
|
kDelegate(cc).Get();
|
||||||
|
CHECK(input_side_packet_delegate.has_tflite() ||
|
||||||
|
input_side_packet_delegate.has_xnnpack() ||
|
||||||
|
input_side_packet_delegate.has_nnapi() ||
|
||||||
|
input_side_packet_delegate.delegate_case() ==
|
||||||
|
mediapipe::InferenceCalculatorOptions::Delegate::DELEGATE_NOT_SET)
|
||||||
|
<< "inference_calculator_cpu only supports delegate input side packet "
|
||||||
|
<< "for TFLite, XNNPack and Nnapi";
|
||||||
|
opts_delegate.MergeFrom(input_side_packet_delegate);
|
||||||
|
}
|
||||||
|
const bool opts_has_delegate =
|
||||||
|
calculator_opts.has_delegate() || !kDelegate(cc).IsEmpty();
|
||||||
|
if (opts_has_delegate && opts_delegate.has_tflite()) {
|
||||||
// Default tflite inference requeqsted - no need to modify graph.
|
// Default tflite inference requeqsted - no need to modify graph.
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(MEDIAPIPE_ANDROID)
|
#if defined(MEDIAPIPE_ANDROID)
|
||||||
const bool nnapi_requested = calculator_opts.has_delegate()
|
const bool nnapi_requested = opts_has_delegate ? opts_delegate.has_nnapi()
|
||||||
? calculator_opts.delegate().has_nnapi()
|
|
||||||
: calculator_opts.use_nnapi();
|
: calculator_opts.use_nnapi();
|
||||||
if (nnapi_requested) {
|
if (nnapi_requested) {
|
||||||
// Attempt to use NNAPI.
|
// Attempt to use NNAPI.
|
||||||
// If not supported, the default CPU delegate will be created and used.
|
// If not supported, the default CPU delegate will be created and used.
|
||||||
interpreter_->SetAllowFp16PrecisionForFp32(1);
|
interpreter_->SetAllowFp16PrecisionForFp32(1);
|
||||||
tflite::StatefulNnApiDelegate::Options options;
|
tflite::StatefulNnApiDelegate::Options options;
|
||||||
const auto& nnapi = calculator_opts.delegate().nnapi();
|
const auto& nnapi = opts_delegate.nnapi();
|
||||||
// Set up cache_dir and model_token for NNAPI compilation cache.
|
// Set up cache_dir and model_token for NNAPI compilation cache.
|
||||||
options.cache_dir =
|
options.cache_dir =
|
||||||
nnapi.has_cache_dir() ? nnapi.cache_dir().c_str() : nullptr;
|
nnapi.has_cache_dir() ? nnapi.cache_dir().c_str() : nullptr;
|
||||||
if (!kNnApiDelegateCacheDir(cc).IsEmpty()) {
|
|
||||||
options.cache_dir = kNnApiDelegateCacheDir(cc).Get().c_str();
|
|
||||||
}
|
|
||||||
options.model_token =
|
options.model_token =
|
||||||
nnapi.has_model_token() ? nnapi.model_token().c_str() : nullptr;
|
nnapi.has_model_token() ? nnapi.model_token().c_str() : nullptr;
|
||||||
if (!kNnApiDelegateModelToken(cc).IsEmpty()) {
|
|
||||||
options.model_token = kNnApiDelegateModelToken(cc).Get().c_str();
|
|
||||||
}
|
|
||||||
delegate_ = TfLiteDelegatePtr(new tflite::StatefulNnApiDelegate(options),
|
delegate_ = TfLiteDelegatePtr(new tflite::StatefulNnApiDelegate(options),
|
||||||
[](TfLiteDelegate*) {});
|
[](TfLiteDelegate*) {});
|
||||||
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
|
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
|
||||||
|
@ -213,13 +222,13 @@ absl::Status InferenceCalculatorCpuImpl::LoadDelegate(CalculatorContext* cc) {
|
||||||
#if defined(__EMSCRIPTEN__)
|
#if defined(__EMSCRIPTEN__)
|
||||||
const bool use_xnnpack = true;
|
const bool use_xnnpack = true;
|
||||||
#else
|
#else
|
||||||
const bool use_xnnpack = calculator_opts.has_delegate() &&
|
const bool use_xnnpack = opts_has_delegate && opts_delegate.has_xnnpack();
|
||||||
calculator_opts.delegate().has_xnnpack();
|
|
||||||
#endif // defined(__EMSCRIPTEN__)
|
#endif // defined(__EMSCRIPTEN__)
|
||||||
|
|
||||||
if (use_xnnpack) {
|
if (use_xnnpack) {
|
||||||
TfLiteXNNPackDelegateOptions xnnpack_opts{};
|
TfLiteXNNPackDelegateOptions xnnpack_opts{};
|
||||||
xnnpack_opts.num_threads = GetXnnpackNumThreads(calculator_opts);
|
xnnpack_opts.num_threads =
|
||||||
|
GetXnnpackNumThreads(opts_has_delegate, opts_delegate);
|
||||||
delegate_ = TfLiteDelegatePtr(TfLiteXNNPackDelegateCreate(&xnnpack_opts),
|
delegate_ = TfLiteDelegatePtr(TfLiteXNNPackDelegateCreate(&xnnpack_opts),
|
||||||
&TfLiteXNNPackDelegateDelete);
|
&TfLiteXNNPackDelegateDelete);
|
||||||
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
|
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
|
||||||
|
|
|
@ -95,19 +95,30 @@ absl::Status InferenceCalculatorGlImpl::UpdateContract(CalculatorContract* cc) {
|
||||||
|
|
||||||
absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) {
|
absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) {
|
||||||
const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>();
|
const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>();
|
||||||
use_advanced_gpu_api_ = options.has_delegate() &&
|
mediapipe::InferenceCalculatorOptions::Delegate delegate = options.delegate();
|
||||||
options.delegate().has_gpu() &&
|
if (!kDelegate(cc).IsEmpty()) {
|
||||||
options.delegate().gpu().use_advanced_gpu_api();
|
mediapipe::InferenceCalculatorOptions::Delegate input_side_packet_delegate =
|
||||||
allow_precision_loss_ = options.delegate().gpu().allow_precision_loss();
|
kDelegate(cc).Get();
|
||||||
tflite_gpu_runner_api_ = options.delegate().gpu().api();
|
CHECK(input_side_packet_delegate.has_gpu() ||
|
||||||
tflite_gpu_runner_usage_ = options.delegate().gpu().usage();
|
input_side_packet_delegate.delegate_case() ==
|
||||||
use_kernel_caching_ = use_advanced_gpu_api_ &&
|
mediapipe::InferenceCalculatorOptions::Delegate::DELEGATE_NOT_SET)
|
||||||
options.delegate().gpu().has_cached_kernel_path();
|
<< "inference_calculator_gl only supports delegate input side packet "
|
||||||
|
<< "for Gpu";
|
||||||
|
delegate.MergeFrom(input_side_packet_delegate);
|
||||||
|
}
|
||||||
|
const bool has_delegate = options.has_delegate() || !kDelegate(cc).IsEmpty();
|
||||||
|
use_advanced_gpu_api_ = has_delegate && delegate.has_gpu() &&
|
||||||
|
delegate.gpu().use_advanced_gpu_api();
|
||||||
|
allow_precision_loss_ = delegate.gpu().allow_precision_loss();
|
||||||
|
tflite_gpu_runner_api_ = delegate.gpu().api();
|
||||||
|
tflite_gpu_runner_usage_ = delegate.gpu().usage();
|
||||||
|
use_kernel_caching_ =
|
||||||
|
use_advanced_gpu_api_ && delegate.gpu().has_cached_kernel_path();
|
||||||
use_gpu_delegate_ = !use_advanced_gpu_api_;
|
use_gpu_delegate_ = !use_advanced_gpu_api_;
|
||||||
|
|
||||||
if (use_kernel_caching_) {
|
if (use_kernel_caching_) {
|
||||||
#ifdef MEDIAPIPE_ANDROID
|
#ifdef MEDIAPIPE_ANDROID
|
||||||
cached_kernel_filename_ = options.delegate().gpu().cached_kernel_path() +
|
cached_kernel_filename_ = delegate.gpu().cached_kernel_path() +
|
||||||
mediapipe::File::Basename(options.model_path()) +
|
mediapipe::File::Basename(options.model_path()) +
|
||||||
".ker";
|
".ker";
|
||||||
#endif // MEDIAPIPE_ANDROID
|
#endif // MEDIAPIPE_ANDROID
|
||||||
|
|
101
mediapipe/calculators/tensor/landmarks_to_tensor_calculator.cc
Normal file
101
mediapipe/calculators/tensor/landmarks_to_tensor_calculator.cc
Normal file
|
@ -0,0 +1,101 @@
|
||||||
|
// Copyright 2021 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "mediapipe/calculators/tensor/landmarks_to_tensor_calculator.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "mediapipe/calculators/tensor/landmarks_to_tensor_calculator.pb.h"
|
||||||
|
#include "mediapipe/framework/api2/node.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||||
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace api2 {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
float GetAttribute(
|
||||||
|
const Landmark& landmark,
|
||||||
|
const LandmarksToTensorCalculatorOptions::Attribute& attribute) {
|
||||||
|
switch (attribute) {
|
||||||
|
case LandmarksToTensorCalculatorOptions::X:
|
||||||
|
return landmark.x();
|
||||||
|
case LandmarksToTensorCalculatorOptions::Y:
|
||||||
|
return landmark.y();
|
||||||
|
case LandmarksToTensorCalculatorOptions::Z:
|
||||||
|
return landmark.z();
|
||||||
|
case LandmarksToTensorCalculatorOptions::VISIBILITY:
|
||||||
|
return landmark.visibility();
|
||||||
|
case LandmarksToTensorCalculatorOptions::PRESENCE:
|
||||||
|
return landmark.presence();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
class LandmarksToTensorCalculatorImpl
|
||||||
|
: public NodeImpl<LandmarksToTensorCalculator> {
|
||||||
|
public:
|
||||||
|
absl::Status Open(CalculatorContext* cc) override {
|
||||||
|
options_ = cc->Options<LandmarksToTensorCalculatorOptions>();
|
||||||
|
RET_CHECK(options_.attributes_size() > 0)
|
||||||
|
<< "At least one attribute must be specified";
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status Process(CalculatorContext* cc) override {
|
||||||
|
if (kInLandmarkList(cc).IsEmpty()) {
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get input landmarks.
|
||||||
|
const auto& in_landmarks = *kInLandmarkList(cc);
|
||||||
|
|
||||||
|
// Determine tensor shape.
|
||||||
|
const int n_landmarks = in_landmarks.landmark_size();
|
||||||
|
const int n_attributes = options_.attributes_size();
|
||||||
|
auto tensor_shape = options_.flatten()
|
||||||
|
? Tensor::Shape{1, n_landmarks * n_attributes}
|
||||||
|
: Tensor::Shape{1, n_landmarks, n_attributes};
|
||||||
|
|
||||||
|
// Create empty tesnor.
|
||||||
|
Tensor tensor(Tensor::ElementType::kFloat32, tensor_shape);
|
||||||
|
auto* buffer = tensor.GetCpuWriteView().buffer<float>();
|
||||||
|
|
||||||
|
// Fill tensor with landmark attributes.
|
||||||
|
for (int i = 0; i < n_landmarks; ++i) {
|
||||||
|
for (int j = 0; j < n_attributes; ++j) {
|
||||||
|
buffer[i * n_attributes + j] =
|
||||||
|
GetAttribute(in_landmarks.landmark(i), options_.attributes(j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return vector with a single tensor.
|
||||||
|
auto result = std::vector<Tensor>();
|
||||||
|
result.push_back(std::move(tensor));
|
||||||
|
kOutTensors(cc).Send(std::move(result));
|
||||||
|
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
LandmarksToTensorCalculatorOptions options_;
|
||||||
|
};
|
||||||
|
MEDIAPIPE_NODE_IMPLEMENTATION(LandmarksToTensorCalculatorImpl);
|
||||||
|
|
||||||
|
} // namespace api2
|
||||||
|
} // namespace mediapipe
|
|
@ -0,0 +1,61 @@
|
||||||
|
// Copyright 2021 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_CALCULATORS_LANDMARKS_TO_TENSOR_CALCULATOR_H_
|
||||||
|
#define MEDIAPIPE_CALCULATORS_LANDMARKS_TO_TENSOR_CALCULATOR_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/api2/node.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||||
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace api2 {
|
||||||
|
|
||||||
|
// A calculator for converting landmars into a Tensor.
|
||||||
|
//
|
||||||
|
// Input:
|
||||||
|
// LANDMARKS - LandmarkList
|
||||||
|
// Landmarks to be converted into a Tensor.
|
||||||
|
//
|
||||||
|
// Output:
|
||||||
|
// TENSORS - std::vector<Tensor>
|
||||||
|
// Vector containing a single Tensor populated with landmark values.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
// node {
|
||||||
|
// calculator: "LandmarksToTensorCalculator"
|
||||||
|
// input_stream: "LANDMARKS:landmarks"
|
||||||
|
// output_stream: "TENSORS:tensors"
|
||||||
|
// options: {
|
||||||
|
// [mediapipe.LandmarksToTensorCalculatorOptions.ext] {
|
||||||
|
// attributes: [X, Y, Z, VISIBILITY, PRESENCE]
|
||||||
|
// # flatten: true
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
class LandmarksToTensorCalculator : public NodeIntf {
|
||||||
|
public:
|
||||||
|
static constexpr Input<LandmarkList>::Optional kInLandmarkList{"LANDMARKS"};
|
||||||
|
static constexpr Output<std::vector<Tensor>> kOutTensors{"TENSORS"};
|
||||||
|
MEDIAPIPE_NODE_INTERFACE(LandmarksToTensorCalculator, kInLandmarkList,
|
||||||
|
kOutTensors);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace api2
|
||||||
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_CALCULATORS_LANDMARKS_TO_TENSOR_CALCULATOR_H_
|
|
@ -0,0 +1,44 @@
|
||||||
|
// Copyright 2021 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
// The option proto for the LandmarksToTensorCalculator.
|
||||||
|
|
||||||
|
syntax = "proto2";
|
||||||
|
|
||||||
|
package mediapipe;
|
||||||
|
|
||||||
|
import "mediapipe/framework/calculator.proto";
|
||||||
|
|
||||||
|
message LandmarksToTensorCalculatorOptions {
|
||||||
|
extend mediapipe.CalculatorOptions {
|
||||||
|
optional LandmarksToTensorCalculatorOptions ext = 394810235;
|
||||||
|
}
|
||||||
|
|
||||||
|
enum Attribute {
|
||||||
|
X = 0;
|
||||||
|
Y = 1;
|
||||||
|
Z = 2;
|
||||||
|
VISIBILITY = 3;
|
||||||
|
PRESENCE = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subset and order of attributes as they should appear in the output Tensor.
|
||||||
|
// Should contain at least one attribute.
|
||||||
|
repeated Attribute attributes = 1;
|
||||||
|
|
||||||
|
// Collapses all landmark attributes into a one dimensional tensor (i.e.
|
||||||
|
// switches from (n_landmarks, n_attributes) to (n_landmarks * n_attributes)
|
||||||
|
// representation).
|
||||||
|
optional bool flatten = 2 [default = false];
|
||||||
|
}
|
|
@ -0,0 +1,155 @@
|
||||||
|
// Copyright 2021 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/memory/memory.h"
|
||||||
|
#include "mediapipe/calculators/tensor/landmarks_to_tensor_calculator.pb.h"
|
||||||
|
#include "mediapipe/framework/calculator.pb.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/calculator_runner.h"
|
||||||
|
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||||
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::mediapipe::ParseTextProtoOrDie;
|
||||||
|
using Node = ::mediapipe::CalculatorGraphConfig::Node;
|
||||||
|
|
||||||
|
void RunLandmarks(mediapipe::CalculatorRunner* runner,
|
||||||
|
const LandmarkList& landmarks) {
|
||||||
|
runner->MutableInputs()
|
||||||
|
->Tag("LANDMARKS")
|
||||||
|
.packets.push_back(MakePacket<LandmarkList>(landmarks).At(Timestamp(0)));
|
||||||
|
MP_ASSERT_OK(runner->Run());
|
||||||
|
}
|
||||||
|
|
||||||
|
const Tensor& GetOutputTensor(mediapipe::CalculatorRunner* runner) {
|
||||||
|
const auto& output_packets = runner->Outputs().Tag("TENSORS").packets;
|
||||||
|
EXPECT_EQ(output_packets.size(), 1);
|
||||||
|
|
||||||
|
const auto& tensors = output_packets[0].Get<std::vector<Tensor>>();
|
||||||
|
EXPECT_EQ(tensors.size(), 1);
|
||||||
|
|
||||||
|
return tensors[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
void ValidateTensor(const Tensor& tensor,
|
||||||
|
const std::vector<int>& expected_shape,
|
||||||
|
const std::vector<float>& expected_values) {
|
||||||
|
EXPECT_EQ(tensor.shape().dims, expected_shape);
|
||||||
|
EXPECT_EQ(tensor.shape().num_elements(), expected_values.size());
|
||||||
|
|
||||||
|
auto* tensor_buffer = tensor.GetCpuReadView().buffer<float>();
|
||||||
|
const std::vector<float> tensor_values(
|
||||||
|
tensor_buffer, tensor_buffer + tensor.shape().num_elements());
|
||||||
|
EXPECT_THAT(tensor_values, testing::ElementsAreArray(expected_values));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(LandmarksToTensorCalculatorTest, AllAttributes) {
|
||||||
|
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||||
|
calculator: "LandmarksToTensorCalculator"
|
||||||
|
input_stream: "LANDMARKS:landmarks"
|
||||||
|
output_stream: "TENSORS:tensors"
|
||||||
|
options: {
|
||||||
|
[mediapipe.LandmarksToTensorCalculatorOptions.ext] {
|
||||||
|
attributes: [ X, Y, Z, VISIBILITY, PRESENCE ]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb"));
|
||||||
|
|
||||||
|
LandmarkList landmarks;
|
||||||
|
auto* landmark1 = landmarks.add_landmark();
|
||||||
|
landmark1->set_x(1.0f);
|
||||||
|
landmark1->set_y(2.0f);
|
||||||
|
landmark1->set_z(3.0f);
|
||||||
|
landmark1->set_visibility(4.0f);
|
||||||
|
landmark1->set_presence(5.0f);
|
||||||
|
auto* landmark2 = landmarks.add_landmark();
|
||||||
|
landmark2->set_x(6.0f);
|
||||||
|
landmark2->set_y(7.0f);
|
||||||
|
landmark2->set_z(8.0f);
|
||||||
|
landmark2->set_visibility(9.0f);
|
||||||
|
landmark2->set_presence(10.0f);
|
||||||
|
|
||||||
|
RunLandmarks(&runner, landmarks);
|
||||||
|
const auto& tensor = GetOutputTensor(&runner);
|
||||||
|
ValidateTensor(tensor, /*expected_shape=*/{1, 2, 5}, /*expected_values=*/
|
||||||
|
{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f});
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(LandmarksToTensorCalculatorTest, XYZAttributes) {
|
||||||
|
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||||
|
calculator: "LandmarksToTensorCalculator"
|
||||||
|
input_stream: "LANDMARKS:landmarks"
|
||||||
|
output_stream: "TENSORS:tensors"
|
||||||
|
options: {
|
||||||
|
[mediapipe.LandmarksToTensorCalculatorOptions.ext] {
|
||||||
|
attributes: [ X, Y, Z ]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb"));
|
||||||
|
|
||||||
|
LandmarkList landmarks;
|
||||||
|
auto* landmark1 = landmarks.add_landmark();
|
||||||
|
landmark1->set_x(1.0f);
|
||||||
|
landmark1->set_y(2.0f);
|
||||||
|
landmark1->set_z(3.0f);
|
||||||
|
auto* landmark2 = landmarks.add_landmark();
|
||||||
|
landmark2->set_x(6.0f);
|
||||||
|
landmark2->set_y(7.0f);
|
||||||
|
landmark2->set_z(8.0f);
|
||||||
|
|
||||||
|
RunLandmarks(&runner, landmarks);
|
||||||
|
const auto& tensor = GetOutputTensor(&runner);
|
||||||
|
ValidateTensor(tensor, /*expected_shape=*/{1, 2, 3}, /*expected_values=*/
|
||||||
|
{1.0f, 2.0f, 3.0f, 6.0f, 7.0f, 8.0f});
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(LandmarksToTensorCalculatorTest, XYZAttributes_Flatten) {
|
||||||
|
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||||
|
calculator: "LandmarksToTensorCalculator"
|
||||||
|
input_stream: "LANDMARKS:landmarks"
|
||||||
|
output_stream: "TENSORS:tensors"
|
||||||
|
options: {
|
||||||
|
[mediapipe.LandmarksToTensorCalculatorOptions.ext] {
|
||||||
|
attributes: [ X, Y, Z ]
|
||||||
|
flatten: true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb"));
|
||||||
|
|
||||||
|
LandmarkList landmarks;
|
||||||
|
auto* landmark1 = landmarks.add_landmark();
|
||||||
|
landmark1->set_x(1.0f);
|
||||||
|
landmark1->set_y(2.0f);
|
||||||
|
landmark1->set_z(3.0f);
|
||||||
|
auto* landmark2 = landmarks.add_landmark();
|
||||||
|
landmark2->set_x(6.0f);
|
||||||
|
landmark2->set_y(7.0f);
|
||||||
|
landmark2->set_z(8.0f);
|
||||||
|
|
||||||
|
RunLandmarks(&runner, landmarks);
|
||||||
|
const auto& tensor = GetOutputTensor(&runner);
|
||||||
|
ValidateTensor(tensor, /*expected_shape=*/{1, 6}, /*expected_values=*/
|
||||||
|
{1.0f, 2.0f, 3.0f, 6.0f, 7.0f, 8.0f});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace mediapipe
|
|
@ -57,6 +57,16 @@ mediapipe_proto_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mediapipe_proto_library(
|
||||||
|
name = "filter_detections_calculator_proto",
|
||||||
|
srcs = ["filter_detections_calculator.proto"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
|
"//mediapipe/framework:calculator_proto",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
mediapipe_proto_library(
|
mediapipe_proto_library(
|
||||||
name = "timed_box_list_id_to_label_calculator_proto",
|
name = "timed_box_list_id_to_label_calculator_proto",
|
||||||
srcs = ["timed_box_list_id_to_label_calculator.proto"],
|
srcs = ["timed_box_list_id_to_label_calculator.proto"],
|
||||||
|
@ -158,6 +168,21 @@ cc_test(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "filter_detections_calculator_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["filter_detections_calculator_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":filter_detections_calculator",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:calculator_runner",
|
||||||
|
"//mediapipe/framework/deps:message_matchers",
|
||||||
|
"//mediapipe/framework/formats:detection_cc_proto",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "packet_latency_calculator",
|
name = "packet_latency_calculator",
|
||||||
srcs = ["packet_latency_calculator.cc"],
|
srcs = ["packet_latency_calculator.cc"],
|
||||||
|
@ -372,6 +397,20 @@ cc_library(
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "filter_detections_calculator",
|
||||||
|
srcs = ["filter_detections_calculator.cc"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":filter_detections_calculator_cc_proto",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework/formats:detection_cc_proto",
|
||||||
|
"//mediapipe/framework/port:status",
|
||||||
|
"@com_google_absl//absl/memory",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "landmarks_to_detection_calculator",
|
name = "landmarks_to_detection_calculator",
|
||||||
srcs = ["landmarks_to_detection_calculator.cc"],
|
srcs = ["landmarks_to_detection_calculator.cc"],
|
||||||
|
|
81
mediapipe/calculators/util/filter_detections_calculator.cc
Normal file
81
mediapipe/calculators/util/filter_detections_calculator.cc
Normal file
|
@ -0,0 +1,81 @@
|
||||||
|
// Copyright 2021 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <iterator>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/memory/memory.h"
|
||||||
|
#include "mediapipe/calculators/util/filter_detections_calculator.pb.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/formats/detection.pb.h"
|
||||||
|
#include "mediapipe/framework/port/status.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
|
||||||
|
const char kInputDetectionsTag[] = "INPUT_DETECTIONS";
|
||||||
|
const char kOutputDetectionsTag[] = "OUTPUT_DETECTIONS";
|
||||||
|
|
||||||
|
//
|
||||||
|
// Calculator to filter out detections that do not meet the criteria specified
|
||||||
|
// in options.
|
||||||
|
//
|
||||||
|
class FilterDetectionsCalculator : public CalculatorBase {
|
||||||
|
public:
|
||||||
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
|
RET_CHECK(cc->Inputs().HasTag(kInputDetectionsTag));
|
||||||
|
RET_CHECK(cc->Outputs().HasTag(kOutputDetectionsTag));
|
||||||
|
|
||||||
|
cc->Inputs().Tag(kInputDetectionsTag).Set<std::vector<Detection>>();
|
||||||
|
cc->Outputs().Tag(kOutputDetectionsTag).Set<std::vector<Detection>>();
|
||||||
|
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status Open(CalculatorContext* cc) override {
|
||||||
|
cc->SetOffset(TimestampDiff(0));
|
||||||
|
options_ = cc->Options<mediapipe::FilterDetectionsCalculatorOptions>();
|
||||||
|
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status Process(CalculatorContext* cc) final {
|
||||||
|
const auto& input_detections =
|
||||||
|
cc->Inputs().Tag(kInputDetectionsTag).Get<std::vector<Detection>>();
|
||||||
|
|
||||||
|
auto output_detections = absl::make_unique<std::vector<Detection>>();
|
||||||
|
|
||||||
|
for (const Detection& detection : input_detections) {
|
||||||
|
RET_CHECK_GT(detection.score_size(), 0);
|
||||||
|
// Note: only score at index 0 supported.
|
||||||
|
if (detection.score(0) >= options_.min_score()) {
|
||||||
|
output_detections->push_back(detection);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cc->Outputs()
|
||||||
|
.Tag(kOutputDetectionsTag)
|
||||||
|
.Add(output_detections.release(), cc->InputTimestamp());
|
||||||
|
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
mediapipe::FilterDetectionsCalculatorOptions options_;
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_CALCULATOR(FilterDetectionsCalculator);
|
||||||
|
|
||||||
|
} // namespace mediapipe
|
|
@ -0,0 +1,28 @@
|
||||||
|
// Copyright 2021 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
syntax = "proto2";
|
||||||
|
|
||||||
|
package mediapipe;
|
||||||
|
|
||||||
|
import "mediapipe/framework/calculator.proto";
|
||||||
|
|
||||||
|
message FilterDetectionsCalculatorOptions {
|
||||||
|
extend mediapipe.CalculatorOptions {
|
||||||
|
optional FilterDetectionsCalculatorOptions ext = 395478132;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Detections lower than this score get filtered out.
|
||||||
|
optional float min_score = 1;
|
||||||
|
}
|
100
mediapipe/calculators/util/filter_detections_calculator_test.cc
Normal file
100
mediapipe/calculators/util/filter_detections_calculator_test.cc
Normal file
|
@ -0,0 +1,100 @@
|
||||||
|
// Copyright 2021 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/calculator_runner.h"
|
||||||
|
#include "mediapipe/framework/deps/message_matchers.h"
|
||||||
|
#include "mediapipe/framework/formats/detection.pb.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::testing::ElementsAre;
|
||||||
|
|
||||||
|
absl::Status RunGraph(std::vector<Detection>& input_detections,
|
||||||
|
std::vector<Detection>* output_detections) {
|
||||||
|
CalculatorRunner runner(R"pb(
|
||||||
|
calculator: "FilterDetectionsCalculator"
|
||||||
|
input_stream: "INPUT_DETECTIONS:input_detections"
|
||||||
|
output_stream: "OUTPUT_DETECTIONS:output_detections"
|
||||||
|
options {
|
||||||
|
[mediapipe.FilterDetectionsCalculatorOptions.ext] { min_score: 0.5 }
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
|
||||||
|
const Timestamp input_timestamp = Timestamp(0);
|
||||||
|
runner.MutableInputs()
|
||||||
|
->Tag("INPUT_DETECTIONS")
|
||||||
|
.packets.push_back(MakePacket<std::vector<Detection>>(input_detections)
|
||||||
|
.At(input_timestamp));
|
||||||
|
MP_RETURN_IF_ERROR(runner.Run()) << "Calculator run failed.";
|
||||||
|
|
||||||
|
const std::vector<Packet>& output_packets =
|
||||||
|
runner.Outputs().Tag("OUTPUT_DETECTIONS").packets;
|
||||||
|
RET_CHECK_EQ(output_packets.size(), 1);
|
||||||
|
|
||||||
|
*output_detections = output_packets[0].Get<std::vector<Detection>>();
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(FilterDetectionsCalculatorTest, TestFilterDetections) {
|
||||||
|
std::vector<Detection> input_detections;
|
||||||
|
Detection d1, d2;
|
||||||
|
d1.add_score(0.2);
|
||||||
|
d2.add_score(0.8);
|
||||||
|
input_detections.push_back(d1);
|
||||||
|
input_detections.push_back(d2);
|
||||||
|
|
||||||
|
std::vector<Detection> output_detections;
|
||||||
|
MP_EXPECT_OK(RunGraph(input_detections, &output_detections));
|
||||||
|
|
||||||
|
EXPECT_THAT(output_detections, ElementsAre(mediapipe::EqualsProto(d2)));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(FilterDetectionsCalculatorTest, TestFilterDetectionsMultiple) {
|
||||||
|
std::vector<Detection> input_detections;
|
||||||
|
Detection d1, d2, d3, d4;
|
||||||
|
d1.add_score(0.3);
|
||||||
|
d2.add_score(0.4);
|
||||||
|
d3.add_score(0.5);
|
||||||
|
d4.add_score(0.6);
|
||||||
|
input_detections.push_back(d1);
|
||||||
|
input_detections.push_back(d2);
|
||||||
|
input_detections.push_back(d3);
|
||||||
|
input_detections.push_back(d4);
|
||||||
|
|
||||||
|
std::vector<Detection> output_detections;
|
||||||
|
MP_EXPECT_OK(RunGraph(input_detections, &output_detections));
|
||||||
|
|
||||||
|
EXPECT_THAT(output_detections, ElementsAre(mediapipe::EqualsProto(d3),
|
||||||
|
mediapipe::EqualsProto(d4)));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(FilterDetectionsCalculatorTest, TestFilterDetectionsEmpty) {
|
||||||
|
std::vector<Detection> input_detections;
|
||||||
|
|
||||||
|
std::vector<Detection> output_detections;
|
||||||
|
MP_EXPECT_OK(RunGraph(input_detections, &output_detections));
|
||||||
|
|
||||||
|
EXPECT_EQ(output_detections.size(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace mediapipe
|
|
@ -25,7 +25,7 @@ node: {
|
||||||
output_stream: "input_video_cpu"
|
output_stream: "input_video_cpu"
|
||||||
}
|
}
|
||||||
|
|
||||||
# Transforms the input image on CPU to a 480x640 image.
|
# Scale the image's longer side to 640, keeping aspect ratio.
|
||||||
node: {
|
node: {
|
||||||
calculator: "ImageTransformationCalculator"
|
calculator: "ImageTransformationCalculator"
|
||||||
input_stream: "IMAGE:input_video_cpu"
|
input_stream: "IMAGE:input_video_cpu"
|
||||||
|
|
|
@ -48,5 +48,14 @@ std::string ClassRegistry::GetMethodName(std::string cls, std::string method) {
|
||||||
return method;
|
return method;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string ClassRegistry::GetFieldName(std::string cls, std::string field) {
|
||||||
|
std::string key = absl::StrFormat("%s##%s", cls, field);
|
||||||
|
auto match = renaming_map_.find(key);
|
||||||
|
if (match != renaming_map_.end()) {
|
||||||
|
return match->second;
|
||||||
|
}
|
||||||
|
return field;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace android
|
} // namespace android
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -33,6 +33,7 @@ class ClassRegistry {
|
||||||
absl::node_hash_map<std::string, std::string> renaming_map);
|
absl::node_hash_map<std::string, std::string> renaming_map);
|
||||||
std::string GetClassName(std::string cls);
|
std::string GetClassName(std::string cls);
|
||||||
std::string GetMethodName(std::string cls, std::string method);
|
std::string GetMethodName(std::string cls, std::string method);
|
||||||
|
std::string GetFieldName(std::string cls, std::string field);
|
||||||
|
|
||||||
// TODO: Just have the prefix instead of all these constants.
|
// TODO: Just have the prefix instead of all these constants.
|
||||||
static constexpr char const* kAndroidAssetUtilClassName =
|
static constexpr char const* kAndroidAssetUtilClassName =
|
||||||
|
@ -59,6 +60,8 @@ class ClassRegistry {
|
||||||
"com/google/mediapipe/framework/PacketGetter";
|
"com/google/mediapipe/framework/PacketGetter";
|
||||||
static constexpr char const* kPacketWithHeaderCallbackClassName =
|
static constexpr char const* kPacketWithHeaderCallbackClassName =
|
||||||
"com/google/mediapipe/framework/PacketWithHeaderCallback";
|
"com/google/mediapipe/framework/PacketWithHeaderCallback";
|
||||||
|
static constexpr char const* kProtoUtilSerializedMessageClassName =
|
||||||
|
"com/google/mediapipe/framework/ProtoUtil$SerializedMessage";
|
||||||
|
|
||||||
private:
|
private:
|
||||||
ClassRegistry();
|
ClassRegistry();
|
||||||
|
|
|
@ -156,10 +156,20 @@ bool ThrowIfError(JNIEnv* env, absl::Status status) {
|
||||||
}
|
}
|
||||||
|
|
||||||
SerializedMessageIds::SerializedMessageIds(JNIEnv* env, jobject data) {
|
SerializedMessageIds::SerializedMessageIds(JNIEnv* env, jobject data) {
|
||||||
jclass j_class = reinterpret_cast<jclass>(env->NewGlobalRef(env->FindClass(
|
auto& class_registry = mediapipe::android::ClassRegistry::GetInstance();
|
||||||
"com/google/mediapipe/framework/ProtoUtil$SerializedMessage")));
|
std::string serialized_message(
|
||||||
type_name_id = env->GetFieldID(j_class, "typeName", "Ljava/lang/String;");
|
mediapipe::android::ClassRegistry::kProtoUtilSerializedMessageClassName);
|
||||||
value_id = env->GetFieldID(j_class, "value", "[B");
|
std::string serialized_message_obfuscated =
|
||||||
|
class_registry.GetClassName(serialized_message);
|
||||||
|
std::string type_name_obfuscated =
|
||||||
|
class_registry.GetFieldName(serialized_message, "typeName");
|
||||||
|
std::string value_obfuscated =
|
||||||
|
class_registry.GetFieldName(serialized_message, "value");
|
||||||
|
jclass j_class = reinterpret_cast<jclass>(
|
||||||
|
env->NewGlobalRef(env->FindClass(serialized_message_obfuscated.c_str())));
|
||||||
|
type_name_id = env->GetFieldID(j_class, type_name_obfuscated.c_str(),
|
||||||
|
"Ljava/lang/String;");
|
||||||
|
value_id = env->GetFieldID(j_class, value_obfuscated.c_str(), "[B");
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace android
|
} // namespace android
|
||||||
|
|
|
@ -225,6 +225,12 @@ void RegisterPacketCreatorNatives(JNIEnv *env) {
|
||||||
AddJNINativeMethod(&packet_creator_methods, packet_creator,
|
AddJNINativeMethod(&packet_creator_methods, packet_creator,
|
||||||
"nativeCreateString", "(JLjava/lang/String;)J",
|
"nativeCreateString", "(JLjava/lang/String;)J",
|
||||||
(void *)&PACKET_CREATOR_METHOD(nativeCreateString));
|
(void *)&PACKET_CREATOR_METHOD(nativeCreateString));
|
||||||
|
std::string serialized_message_name = class_registry.GetClassName(
|
||||||
|
mediapipe::android::ClassRegistry::kProtoUtilSerializedMessageClassName);
|
||||||
|
AddJNINativeMethod(&packet_creator_methods, packet_creator,
|
||||||
|
"nativeCreateProto",
|
||||||
|
"(JL" + serialized_message_name + ";)J",
|
||||||
|
(void *)&PACKET_CREATOR_METHOD(nativeCreateProto));
|
||||||
RegisterNativesVector(env, packet_creator_class, packet_creator_methods);
|
RegisterNativesVector(env, packet_creator_class, packet_creator_methods);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -119,6 +119,18 @@ public class ImageSolutionBase extends SolutionBase {
|
||||||
"Receving a frame with invalid timestamp."));
|
"Receving a frame with invalid timestamp."));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (!solutionGraphStarted.get()) {
|
||||||
|
if (imageObj instanceof TextureFrame) {
|
||||||
|
((TextureFrame) imageObj).release();
|
||||||
|
}
|
||||||
|
throwException(
|
||||||
|
"The solution graph hasn't been successfully started or error occurs during graph"
|
||||||
|
+ " initializaton.",
|
||||||
|
new MediaPipeException(
|
||||||
|
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
||||||
|
"Graph is not started."));
|
||||||
|
return;
|
||||||
|
}
|
||||||
lastTimestamp = timestamp;
|
lastTimestamp = timestamp;
|
||||||
Packet imagePacket = null;
|
Packet imagePacket = null;
|
||||||
try {
|
try {
|
||||||
|
|
|
@ -87,6 +87,7 @@ public class SolutionBase {
|
||||||
} else {
|
} else {
|
||||||
Log.e(TAG, message, e);
|
Log.e(TAG, message, e);
|
||||||
}
|
}
|
||||||
|
throw e;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
BIN
mediapipe/modules/hand_landmark/hand_landmark.tflite
Normal file → Executable file
BIN
mediapipe/modules/hand_landmark/hand_landmark.tflite
Normal file → Executable file
Binary file not shown.
|
@ -92,11 +92,11 @@ class SolveEpnpTest : public Test {
|
||||||
const float scale = output_3d_points[0].z() / expected_3d_points_[0].z();
|
const float scale = output_3d_points[0].z() / expected_3d_points_[0].z();
|
||||||
for (int i = 0; i < kNumKeypoints; ++i) {
|
for (int i = 0; i < kNumKeypoints; ++i) {
|
||||||
EXPECT_NEAR(output_3d_points[i].x(), expected_3d_points_[i].x() * scale,
|
EXPECT_NEAR(output_3d_points[i].x(), expected_3d_points_[i].x() * scale,
|
||||||
1.e-6f);
|
2.e-6f);
|
||||||
EXPECT_NEAR(output_3d_points[i].y(), expected_3d_points_[i].y() * scale,
|
EXPECT_NEAR(output_3d_points[i].y(), expected_3d_points_[i].y() * scale,
|
||||||
1.e-6f);
|
2.e-6f);
|
||||||
EXPECT_NEAR(output_3d_points[i].z(), expected_3d_points_[i].z() * scale,
|
EXPECT_NEAR(output_3d_points[i].z(), expected_3d_points_[i].z() * scale,
|
||||||
1.e-6f);
|
2.e-6f);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,555 @@
|
||||||
|
// Copyright 2021 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/mediapipe/landmarks_to_transform_matrix.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/common.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||||
|
#include "tensorflow/lite/kernels/padding.h"
|
||||||
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
|
|
||||||
|
using ::tflite::gpu::BHWC;
|
||||||
|
using ::tflite::gpu::float2;
|
||||||
|
using ::tflite::gpu::float3;
|
||||||
|
using ::tflite::gpu::int2;
|
||||||
|
using ::tflite::gpu::int3;
|
||||||
|
using ::tflite::gpu::LandmarksToTransformMatrixV1Attributes;
|
||||||
|
using ::tflite::gpu::LandmarksToTransformMatrixV2Attributes;
|
||||||
|
|
||||||
|
using ::tflite::GetInput;
|
||||||
|
using ::tflite::GetOutput;
|
||||||
|
using ::tflite::GetTensorData;
|
||||||
|
using ::tflite::GetTensorShape;
|
||||||
|
using ::tflite::NumDimensions;
|
||||||
|
using ::tflite::NumInputs;
|
||||||
|
using ::tflite::NumOutputs;
|
||||||
|
using ::tflite::RuntimeShape;
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace tflite_operations {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
constexpr int kDataInputTensor = 0;
|
||||||
|
constexpr int kOutputTensor = 0;
|
||||||
|
constexpr int3 kTensformMatrixShape(1, 4, 4);
|
||||||
|
|
||||||
|
float2 Read3DLandmarkXY(const float* data, int idx) {
|
||||||
|
float2 result;
|
||||||
|
result.x = data[idx * 3];
|
||||||
|
result.y = data[idx * 3 + 1];
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
float3 Read3DLandmarkXYZ(const float* data, int idx) {
|
||||||
|
float3 result;
|
||||||
|
result.x = data[idx * 3];
|
||||||
|
result.y = data[idx * 3 + 1];
|
||||||
|
result.z = data[idx * 3 + 2];
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Mat3 {
|
||||||
|
Mat3() { data.resize(9); }
|
||||||
|
Mat3(float x00, float x01, float x02, float x10, float x11, float x12,
|
||||||
|
float x20, float x21, float x22)
|
||||||
|
: data{x00, x01, x02, x10, x11, x12, x20, x21, x22} {}
|
||||||
|
|
||||||
|
Mat3 operator*(const Mat3& other) {
|
||||||
|
Mat3 result;
|
||||||
|
for (int r = 0; r < 3; r++) {
|
||||||
|
for (int c = 0; c < 3; c++) {
|
||||||
|
float sum = 0;
|
||||||
|
for (int k = 0; k < 3; k++) {
|
||||||
|
sum += this->Get(r, k) * other.Get(k, c);
|
||||||
|
}
|
||||||
|
result.Set(r, c, sum);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
float3 operator*(const float3& vec) const {
|
||||||
|
float3 result;
|
||||||
|
for (int r = 0; r < 3; r++) {
|
||||||
|
float sum = 0;
|
||||||
|
for (int k = 0; k < 3; k++) {
|
||||||
|
sum += this->Get(r, k) * vec[k];
|
||||||
|
}
|
||||||
|
result[r] = sum;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
float Get(int x, int y) const { return data[x * 3 + y]; }
|
||||||
|
void Set(int x, int y, float val) { data[x * 3 + y] = val; }
|
||||||
|
|
||||||
|
std::vector<float> data;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Mat4 {
|
||||||
|
Mat4() { data.resize(16); }
|
||||||
|
Mat4(float x00, float x01, float x02, float x03, float x10, float x11,
|
||||||
|
float x12, float x13, float x20, float x21, float x22, float x23,
|
||||||
|
float x30, float x31, float x32, float x33)
|
||||||
|
: data{x00, x01, x02, x03, x10, x11, x12, x13,
|
||||||
|
x20, x21, x22, x23, x30, x31, x32, x33} {}
|
||||||
|
void operator*=(const Mat4& other) {
|
||||||
|
Mat4 result;
|
||||||
|
for (int r = 0; r < 4; r++) {
|
||||||
|
for (int c = 0; c < 4; c++) {
|
||||||
|
float sum = 0;
|
||||||
|
for (int k = 0; k < 4; k++) {
|
||||||
|
sum += this->Get(r, k) * other.Get(k, c);
|
||||||
|
}
|
||||||
|
result.Set(r, c, sum);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::memcpy(this->data.data(), result.data.data(),
|
||||||
|
result.data.size() * sizeof(float));
|
||||||
|
}
|
||||||
|
float Get(int x, int y) const { return data[x * 4 + y]; }
|
||||||
|
void Set(int x, int y, float val) { data[x * 4 + y] = val; }
|
||||||
|
|
||||||
|
std::vector<float> data;
|
||||||
|
};
|
||||||
|
|
||||||
|
namespace v1 {
|
||||||
|
|
||||||
|
inline void LandmarksToTransformMatrixV1(
|
||||||
|
const LandmarksToTransformMatrixV1Attributes& params,
|
||||||
|
const RuntimeShape& input0_shape, const float* landmarks,
|
||||||
|
const RuntimeShape& output_shape, float* output_data) {
|
||||||
|
TFLITE_CHECK_EQ(input0_shape.DimensionsCount(), 4);
|
||||||
|
TFLITE_CHECK_EQ(output_shape.DimensionsCount(), 3);
|
||||||
|
TFLITE_CHECK_EQ(input0_shape.Dims(0), 1);
|
||||||
|
TFLITE_CHECK_EQ(input0_shape.Dims(1), 1);
|
||||||
|
TFLITE_CHECK_EQ(input0_shape.Dims(2), 1);
|
||||||
|
|
||||||
|
float2 left_landmark = Read3DLandmarkXY(landmarks, params.left_rotation_idx);
|
||||||
|
float2 right_landmark =
|
||||||
|
Read3DLandmarkXY(landmarks, params.right_rotation_idx);
|
||||||
|
|
||||||
|
float alpha = -std::atan((right_landmark.y - left_landmark.y) /
|
||||||
|
(right_landmark.x - left_landmark.x));
|
||||||
|
|
||||||
|
float2 max_value(-100000, -100000);
|
||||||
|
float2 min_value(100000, 100000);
|
||||||
|
for (int i = 0; i < params.subset.size(); i++) {
|
||||||
|
for (int j = 0; j < 2; j++) {
|
||||||
|
float2 landmark_current =
|
||||||
|
Read3DLandmarkXY(landmarks, params.subset[i][j]);
|
||||||
|
float2 rotated(
|
||||||
|
landmark_current.x * cos(alpha) - landmark_current.y * sin(alpha),
|
||||||
|
landmark_current.x * sin(alpha) + landmark_current.y * cos(alpha));
|
||||||
|
max_value = float2(std::max(max_value.x, rotated.x),
|
||||||
|
std::max(max_value.y, rotated.y));
|
||||||
|
min_value = float2(std::min(min_value.x, rotated.x),
|
||||||
|
std::min(min_value.y, rotated.y));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
float2 bbox_size((max_value.x - min_value.x) * params.bbox_size_multiplier,
|
||||||
|
(max_value.y - min_value.y) * params.bbox_size_multiplier);
|
||||||
|
|
||||||
|
Mat3 scale_matrix(
|
||||||
|
bbox_size.x / params.landmarks_range, 0.0, 0.0, // first row
|
||||||
|
0.0, bbox_size.y / params.landmarks_range, 0.0, // second row
|
||||||
|
0.0, 0.0, 1.0); // third row
|
||||||
|
|
||||||
|
float2 middle((max_value.x + min_value.x) / 2.0,
|
||||||
|
(max_value.y + min_value.y) / 2.0);
|
||||||
|
|
||||||
|
float2 rotated_middle(middle.x * cos(-alpha) - middle.y * sin(-alpha),
|
||||||
|
middle.x * sin(-alpha) + middle.y * cos(-alpha));
|
||||||
|
|
||||||
|
Mat3 rotation_matrix(
|
||||||
|
cos(-alpha), -sin(-alpha),
|
||||||
|
(rotated_middle.x / params.landmarks_range) * 2.0 - 1.0, // first row
|
||||||
|
sin(-alpha), cos(-alpha),
|
||||||
|
(rotated_middle.y / params.landmarks_range) * 2.0 - 1.0, // second row
|
||||||
|
0, 0, 1); // third row
|
||||||
|
|
||||||
|
Mat3 to_relative(2.0 / (params.output_hw.w - 1.0), 0.0, -1.0, // first row
|
||||||
|
0.0, 2.0 / (params.output_hw.h - 1.0), -1.0, // second row
|
||||||
|
0.0, 0.0, 1.0); // third row
|
||||||
|
|
||||||
|
Mat3 to_absolute((params.input_hw.w - 1.0) / 2.0, 0.0,
|
||||||
|
(params.input_hw.w - 1.0) / 2.0, // first row
|
||||||
|
0.0, (params.input_hw.h - 1.0) / 2.0,
|
||||||
|
(params.input_hw.h - 1.0) / 2.0, // second row
|
||||||
|
0.0, 0.0, 1.0); // third row
|
||||||
|
|
||||||
|
// Inverse Transformstion Matrix
|
||||||
|
Mat3 itm = to_absolute * rotation_matrix * scale_matrix * to_relative;
|
||||||
|
|
||||||
|
output_data[0] = itm.Get(0, 0);
|
||||||
|
output_data[1] = itm.Get(0, 1);
|
||||||
|
output_data[2] = 0.0;
|
||||||
|
output_data[3] = itm.Get(0, 2);
|
||||||
|
|
||||||
|
output_data[4] = itm.Get(1, 0);
|
||||||
|
output_data[5] = itm.Get(1, 1);
|
||||||
|
output_data[6] = 0.0;
|
||||||
|
output_data[7] = itm.Get(1, 2);
|
||||||
|
|
||||||
|
output_data[8] = itm.Get(2, 0);
|
||||||
|
output_data[9] = itm.Get(2, 1);
|
||||||
|
output_data[10] = itm.Get(2, 2);
|
||||||
|
output_data[11] = 0.0;
|
||||||
|
|
||||||
|
output_data[12] = 0.0;
|
||||||
|
output_data[13] = 0.0;
|
||||||
|
output_data[14] = 0.0;
|
||||||
|
output_data[15] = 1.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||||
|
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||||
|
const TfLiteTensor* input = GetInput(context, node, kDataInputTensor);
|
||||||
|
TF_LITE_ENSURE(context, input != nullptr);
|
||||||
|
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||||
|
TF_LITE_ENSURE(context, output != nullptr);
|
||||||
|
|
||||||
|
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
|
||||||
|
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
|
||||||
|
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
|
||||||
|
|
||||||
|
TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
|
||||||
|
output_size->data[0] = kTensformMatrixShape.x;
|
||||||
|
output_size->data[1] = kTensformMatrixShape.y;
|
||||||
|
output_size->data[2] = kTensformMatrixShape.z;
|
||||||
|
|
||||||
|
return context->ResizeTensor(context, output, output_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
LandmarksToTransformMatrixV1Attributes op_params;
|
||||||
|
BHWC output_shape;
|
||||||
|
auto status = tflite::gpu::ParseLandmarksToTransformMatrixV1Attributes(
|
||||||
|
node->custom_initial_data, node->custom_initial_data_size, &op_params,
|
||||||
|
&output_shape);
|
||||||
|
if (!status.ok()) {
|
||||||
|
context->ReportError(context, status.message().data());
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (op_params.bbox_size_multiplier == 0) {
|
||||||
|
context->ReportError(context, "Incorrect bbox_size_multiplier: %d",
|
||||||
|
op_params.bbox_size_multiplier);
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (op_params.dimensions != 3) {
|
||||||
|
context->ReportError(context, "Incorrect dimensions: %d",
|
||||||
|
op_params.dimensions);
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (op_params.input_hw.h <= 0 || op_params.input_hw.w <= 0) {
|
||||||
|
context->ReportError(context, "Incorrect input_hw: h = %d w = %d",
|
||||||
|
op_params.input_hw.h, op_params.input_hw.w);
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (op_params.output_hw.h <= 0 || op_params.output_hw.w <= 0) {
|
||||||
|
context->ReportError(context, "Incorrect output_hw: h = %d w = %d",
|
||||||
|
op_params.output_hw.h, op_params.output_hw.w);
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (op_params.landmarks_range <= 0) {
|
||||||
|
context->ReportError(context, "Incorrect landmarks_range: %d",
|
||||||
|
op_params.landmarks_range);
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (op_params.left_rotation_idx < 0) {
|
||||||
|
context->ReportError(context, "Incorrect left_rotation_idx: %d",
|
||||||
|
op_params.left_rotation_idx);
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (op_params.right_rotation_idx < 0) {
|
||||||
|
context->ReportError(context, "Incorrect right_rotation_idx: %d",
|
||||||
|
op_params.right_rotation_idx);
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (op_params.subset.empty()) {
|
||||||
|
context->ReportError(context, "Subset parameter is empty");
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
int counter = 0;
|
||||||
|
for (auto& val : op_params.subset) {
|
||||||
|
for (int i = 0; i < 2; i++) {
|
||||||
|
if (val[i] < 0) {
|
||||||
|
context->ReportError(context,
|
||||||
|
"Incorrect subset value: index = %d, value = %d",
|
||||||
|
counter, val[i]);
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
counter++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const TfLiteTensor* input0 = GetInput(context, node, kDataInputTensor);
|
||||||
|
TF_LITE_ENSURE(context, input0 != nullptr);
|
||||||
|
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||||
|
TF_LITE_ENSURE(context, output != nullptr);
|
||||||
|
|
||||||
|
LandmarksToTransformMatrixV1(
|
||||||
|
op_params, GetTensorShape(input0), GetTensorData<float>(input0),
|
||||||
|
GetTensorShape(output), GetTensorData<float>(output));
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
} // namespace v1
|
||||||
|
|
||||||
|
namespace v2 {
|
||||||
|
|
||||||
|
void EstimateRotationRadians(const float* input_data_0, int left_rotation_idx,
|
||||||
|
int right_rotation_idx,
|
||||||
|
float target_rotation_radians,
|
||||||
|
float* rotation_radians) {
|
||||||
|
const float3 left_landmark =
|
||||||
|
Read3DLandmarkXYZ(input_data_0, left_rotation_idx);
|
||||||
|
const float3 right_landmark =
|
||||||
|
Read3DLandmarkXYZ(input_data_0, right_rotation_idx);
|
||||||
|
const float left_x = left_landmark[0];
|
||||||
|
const float left_y = left_landmark[1];
|
||||||
|
const float right_x = right_landmark[0];
|
||||||
|
const float right_y = right_landmark[1];
|
||||||
|
float rotation = std::atan2(right_y - left_y, right_x - left_x);
|
||||||
|
rotation = target_rotation_radians - rotation;
|
||||||
|
*rotation_radians = rotation;
|
||||||
|
}
|
||||||
|
|
||||||
|
void EstimateCenterAndSize(const float* input_data_0,
|
||||||
|
std::vector<tflite::gpu::int2> subset_idxs,
|
||||||
|
float rotation_radians, float* crop_x, float* crop_y,
|
||||||
|
float* crop_width, float* crop_height) {
|
||||||
|
std::vector<float3> landmarks;
|
||||||
|
landmarks.reserve(subset_idxs.size() * 2);
|
||||||
|
for (int i = 0; i < subset_idxs.size(); i++) {
|
||||||
|
landmarks.push_back(Read3DLandmarkXYZ(input_data_0, subset_idxs[i][0]));
|
||||||
|
landmarks.push_back(Read3DLandmarkXYZ(input_data_0, subset_idxs[i][1]));
|
||||||
|
}
|
||||||
|
for (int i = 0; i < landmarks.size(); i++) {
|
||||||
|
landmarks[i].z = 1.0;
|
||||||
|
}
|
||||||
|
const float& r = rotation_radians;
|
||||||
|
// clang-format off
|
||||||
|
const Mat3 t_rotation = Mat3(std::cos(r), -std::sin(r), 0.0,
|
||||||
|
std::sin(r), std::cos(r), 0.0,
|
||||||
|
0.0, 0.0, 1.0);
|
||||||
|
const Mat3 t_rotation_inverse =
|
||||||
|
Mat3(std::cos(-r), -std::sin(-r), 0.0,
|
||||||
|
std::sin(-r), std::cos(-r), 0.0,
|
||||||
|
0.0, 0.0, 1.0);
|
||||||
|
// clang-format on
|
||||||
|
for (int i = 0; i < landmarks.size(); i++) {
|
||||||
|
landmarks[i] = t_rotation * landmarks[i];
|
||||||
|
}
|
||||||
|
float3 xy1_max = landmarks[0], xy1_min = landmarks[0];
|
||||||
|
for (int i = 1; i < landmarks.size(); i++) {
|
||||||
|
if (xy1_max.x < landmarks[i].x) xy1_max.x = landmarks[i].x;
|
||||||
|
if (xy1_max.y < landmarks[i].y) xy1_max.y = landmarks[i].y;
|
||||||
|
|
||||||
|
if (xy1_min.x > landmarks[i].x) xy1_min.x = landmarks[i].x;
|
||||||
|
if (xy1_min.y > landmarks[i].y) xy1_min.y = landmarks[i].y;
|
||||||
|
}
|
||||||
|
*crop_width = xy1_max.x - xy1_min.x;
|
||||||
|
*crop_height = xy1_max.y - xy1_min.y;
|
||||||
|
float3 crop_xy1 = xy1_min;
|
||||||
|
crop_xy1.x += xy1_max.x;
|
||||||
|
crop_xy1.y += xy1_max.y;
|
||||||
|
crop_xy1.x /= 2;
|
||||||
|
crop_xy1.y /= 2;
|
||||||
|
crop_xy1 = t_rotation_inverse * crop_xy1;
|
||||||
|
*crop_x = crop_xy1.x;
|
||||||
|
*crop_y = crop_xy1.y;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void LandmarksToTransformMatrixV2(
|
||||||
|
const LandmarksToTransformMatrixV2Attributes& params,
|
||||||
|
const RuntimeShape& input0_shape, const float* landmarks,
|
||||||
|
const RuntimeShape& output_shape, float* output_data) {
|
||||||
|
float rotation_radians = 0.0;
|
||||||
|
EstimateRotationRadians(landmarks, params.left_rotation_idx,
|
||||||
|
params.right_rotation_idx,
|
||||||
|
params.target_rotation_radians, &rotation_radians);
|
||||||
|
float crop_x = 0.0, crop_y = 0.0, crop_width = 0.0, crop_height = 0.0;
|
||||||
|
EstimateCenterAndSize(landmarks, params.subset_idxs, rotation_radians,
|
||||||
|
&crop_x, &crop_y, &crop_width, &crop_height);
|
||||||
|
// Turn off clang formatting to make matrices initialization more readable.
|
||||||
|
// clang-format off
|
||||||
|
Mat4 t = Mat4(1.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 1.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 1.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 1.0);
|
||||||
|
const Mat4 t_shift = Mat4(1.0, 0.0, 0.0, crop_x,
|
||||||
|
0.0, 1.0, 0.0, crop_y,
|
||||||
|
0.0, 0.0, 1.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 1.0);
|
||||||
|
t *= t_shift;
|
||||||
|
const float& r = -rotation_radians;
|
||||||
|
const Mat4 t_rotation = Mat4(std::cos(r), -std::sin(r), 0.0, 0.0,
|
||||||
|
std::sin(r), std::cos(r), 0.0, 0.0,
|
||||||
|
0.0, 0.0, 1.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 1.0);
|
||||||
|
t *= t_rotation;
|
||||||
|
const float scale_x = params.scale_x * crop_width / params.output_width;
|
||||||
|
const float scale_y = params.scale_y * crop_height / params.output_height;
|
||||||
|
const Mat4 t_scale = Mat4(scale_x, 0.0, 0.0, 0.0,
|
||||||
|
0.0, scale_y, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 1.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 1.0);
|
||||||
|
t *= t_scale;
|
||||||
|
const float shift_x = -1.0 * (params.output_width / 2.0);
|
||||||
|
const float shift_y = -1.0 * (params.output_height / 2.0);
|
||||||
|
const Mat4 t_shift2 = Mat4(1.0, 0.0, 0.0, shift_x,
|
||||||
|
0.0, 1.0, 0.0, shift_y,
|
||||||
|
0.0, 0.0, 1.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 1.0);
|
||||||
|
t *= t_shift2;
|
||||||
|
std::memcpy(output_data, t.data.data(), 16 * sizeof(float));
|
||||||
|
// clang-format on
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||||
|
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||||
|
const TfLiteTensor* input = GetInput(context, node, kDataInputTensor);
|
||||||
|
TF_LITE_ENSURE(context, input != nullptr);
|
||||||
|
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||||
|
TF_LITE_ENSURE(context, output != nullptr);
|
||||||
|
|
||||||
|
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 3);
|
||||||
|
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
|
||||||
|
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
|
||||||
|
|
||||||
|
TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
|
||||||
|
output_size->data[0] = kTensformMatrixShape.x;
|
||||||
|
output_size->data[1] = kTensformMatrixShape.y;
|
||||||
|
output_size->data[2] = kTensformMatrixShape.z;
|
||||||
|
|
||||||
|
return context->ResizeTensor(context, output, output_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
LandmarksToTransformMatrixV2Attributes op_params;
|
||||||
|
BHWC output_shape;
|
||||||
|
auto status = tflite::gpu::ParseLandmarksToTransformMatrixV2Attributes(
|
||||||
|
node->custom_initial_data, node->custom_initial_data_size, &op_params,
|
||||||
|
&output_shape);
|
||||||
|
if (!status.ok()) {
|
||||||
|
context->ReportError(context, status.message().data());
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (op_params.left_rotation_idx < 0) {
|
||||||
|
context->ReportError(context, "Incorrect left_rotation_idx: %d",
|
||||||
|
op_params.left_rotation_idx);
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (op_params.right_rotation_idx < 0) {
|
||||||
|
context->ReportError(context, "Incorrect right_rotation_idx: %d",
|
||||||
|
op_params.right_rotation_idx);
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (op_params.output_height <= 0) {
|
||||||
|
context->ReportError(context, "Incorrect output_height: %d",
|
||||||
|
op_params.output_height);
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (op_params.output_width <= 0) {
|
||||||
|
context->ReportError(context, "Incorrect output_width: %d",
|
||||||
|
op_params.output_width);
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (op_params.scale_x <= 0) {
|
||||||
|
context->ReportError(context, "Incorrect scale_x: %d", op_params.scale_x);
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (op_params.scale_y <= 0) {
|
||||||
|
context->ReportError(context, "Incorrect scale_y: %d", op_params.scale_y);
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
int counter = 0;
|
||||||
|
for (auto& val : op_params.subset_idxs) {
|
||||||
|
for (int i = 0; i < 2; i++) {
|
||||||
|
if (val[i] < 0) {
|
||||||
|
context->ReportError(context,
|
||||||
|
"Incorrect subset value: index = %d, value = %d",
|
||||||
|
counter, val[i]);
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
counter++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const TfLiteTensor* input0 = GetInput(context, node, kDataInputTensor);
|
||||||
|
TF_LITE_ENSURE(context, input0 != nullptr);
|
||||||
|
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||||
|
TF_LITE_ENSURE(context, output != nullptr);
|
||||||
|
|
||||||
|
LandmarksToTransformMatrixV2(
|
||||||
|
op_params, GetTensorShape(input0), GetTensorData<float>(input0),
|
||||||
|
GetTensorShape(output), GetTensorData<float>(output));
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace v2
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
TfLiteRegistration* RegisterLandmarksToTransformMatrixV1() {
|
||||||
|
static TfLiteRegistration reg = {
|
||||||
|
/*.init=*/nullptr,
|
||||||
|
/*.free=*/nullptr,
|
||||||
|
/*.prepare=*/v1::Prepare,
|
||||||
|
/*.invoke=*/v1::Eval,
|
||||||
|
/*.profiling_string=*/nullptr,
|
||||||
|
/*.builtin_code=*/tflite::BuiltinOperator_CUSTOM,
|
||||||
|
/*.custom_name=*/"Landmarks2TransformMatrix",
|
||||||
|
/*.version=*/1,
|
||||||
|
};
|
||||||
|
return ®
|
||||||
|
}
|
||||||
|
TfLiteRegistration* RegisterLandmarksToTransformMatrixV2() {
|
||||||
|
static TfLiteRegistration reg = {
|
||||||
|
/*.init=*/nullptr,
|
||||||
|
/*.free=*/nullptr,
|
||||||
|
/*.prepare=*/v2::Prepare,
|
||||||
|
/*.invoke=*/v2::Eval,
|
||||||
|
/*.profiling_string=*/nullptr,
|
||||||
|
/*.builtin_code=*/tflite::BuiltinOperator_CUSTOM,
|
||||||
|
/*.custom_name=*/"Landmarks2TransformMatrix",
|
||||||
|
/*.version=*/2,
|
||||||
|
};
|
||||||
|
return ®
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tflite_operations
|
||||||
|
} // namespace mediapipe
|
|
@ -0,0 +1,30 @@
|
||||||
|
// Copyright 2021 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_UTIL_TFLITE_OPERATIONS_LANDMARKS_TO_TRANSFORM_MATRIX_H_
|
||||||
|
#define MEDIAPIPE_UTIL_TFLITE_OPERATIONS_LANDMARKS_TO_TRANSFORM_MATRIX_H_
|
||||||
|
|
||||||
|
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace tflite_operations {
|
||||||
|
|
||||||
|
TfLiteRegistration* RegisterLandmarksToTransformMatrixV1();
|
||||||
|
|
||||||
|
TfLiteRegistration* RegisterLandmarksToTransformMatrixV2();
|
||||||
|
|
||||||
|
} // namespace tflite_operations
|
||||||
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_UTIL_TFLITE_OPERATIONS_LANDMARKS_TO_TRANSFORM_MATRIX_H_
|
298
mediapipe/util/tflite/operations/transform_landmarks.cc
Normal file
298
mediapipe/util/tflite/operations/transform_landmarks.cc
Normal file
|
@ -0,0 +1,298 @@
|
||||||
|
// Copyright 2021 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "mediapipe/util/tflite/operations/transform_landmarks.h"
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/mediapipe/transform_landmarks.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/common.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||||
|
#include "tensorflow/lite/kernels/padding.h"
|
||||||
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace tflite_operations {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
constexpr int kDataInput0Tensor = 0;
|
||||||
|
constexpr int kDataInput1Tensor = 1;
|
||||||
|
constexpr int kOutputTensor = 0;
|
||||||
|
|
||||||
|
float DotProduct(const tflite::gpu::float4& l, const tflite::gpu::float4& r) {
|
||||||
|
return l.x * r.x + l.y * r.y + l.z * r.z + l.w * r.w;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace v1 {
|
||||||
|
|
||||||
|
inline void TransformLandmarks(
|
||||||
|
const tflite::gpu::TransformLandmarksAttributes& params,
|
||||||
|
const tflite::RuntimeShape& input0_shape, const float* landmarks,
|
||||||
|
const tflite::RuntimeShape& input1_shape, const float* transform_matrix,
|
||||||
|
const tflite::RuntimeShape& output_shape, float* output_data) {
|
||||||
|
TFLITE_CHECK_EQ(input0_shape.DimensionsCount(), 4);
|
||||||
|
TFLITE_CHECK_EQ(output_shape.DimensionsCount(), 4);
|
||||||
|
const int output_height = output_shape.Dims(1);
|
||||||
|
const int output_width = output_shape.Dims(2);
|
||||||
|
const int output_channels = output_shape.Dims(3);
|
||||||
|
TFLITE_CHECK_EQ(input0_shape.Dims(3) % params.dimensions, 0);
|
||||||
|
TFLITE_CHECK_NE(params.scale, 0);
|
||||||
|
|
||||||
|
tflite::RuntimeShape input_shape_with_batch{/*batch=*/1, input0_shape.Dims(1),
|
||||||
|
input0_shape.Dims(2),
|
||||||
|
input0_shape.Dims(3)};
|
||||||
|
tflite::RuntimeShape output_shape_with_batch{
|
||||||
|
/*batch=*/1, output_shape.Dims(1), output_shape.Dims(2),
|
||||||
|
output_shape.Dims(3)};
|
||||||
|
|
||||||
|
// Read first two rows of transformation matrix
|
||||||
|
tflite::gpu::float4 x_transform(transform_matrix[0], transform_matrix[1],
|
||||||
|
transform_matrix[2],
|
||||||
|
transform_matrix[3] * params.scale);
|
||||||
|
tflite::gpu::float4 y_transform(transform_matrix[4], transform_matrix[5],
|
||||||
|
transform_matrix[6],
|
||||||
|
transform_matrix[7] * params.scale);
|
||||||
|
|
||||||
|
for (int out_y = 0; out_y < output_height; ++out_y) {
|
||||||
|
for (int out_x = 0; out_x < output_width; ++out_x) {
|
||||||
|
for (int landmark = 0; landmark < output_channels / params.dimensions;
|
||||||
|
++landmark) {
|
||||||
|
const int offset = Offset(output_shape_with_batch, 0, out_y, out_x,
|
||||||
|
landmark * params.dimensions);
|
||||||
|
|
||||||
|
if (params.dimensions == 2) {
|
||||||
|
tflite::gpu::float4 lv(landmarks[offset], landmarks[offset + 1],
|
||||||
|
static_cast<float>(0.0),
|
||||||
|
static_cast<float>(1.0));
|
||||||
|
tflite::gpu::float2 transformed(DotProduct(x_transform, lv),
|
||||||
|
DotProduct(y_transform, lv));
|
||||||
|
output_data[offset] = transformed.x;
|
||||||
|
output_data[offset + 1] = transformed.y;
|
||||||
|
}
|
||||||
|
if (params.dimensions == 3) {
|
||||||
|
tflite::gpu::float4 lv(landmarks[offset], landmarks[offset + 1],
|
||||||
|
static_cast<float>(0.0),
|
||||||
|
static_cast<float>(1.0));
|
||||||
|
tflite::gpu::float3 transformed(DotProduct(x_transform, lv),
|
||||||
|
DotProduct(y_transform, lv), lv.z);
|
||||||
|
output_data[offset] = transformed.x;
|
||||||
|
output_data[offset + 1] = transformed.y;
|
||||||
|
output_data[offset + 2] = landmarks[offset + 2];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
TF_LITE_ENSURE_EQ(context, tflite::NumInputs(node), 2);
|
||||||
|
TF_LITE_ENSURE_EQ(context, tflite::NumOutputs(node), 1);
|
||||||
|
const TfLiteTensor* input =
|
||||||
|
tflite::GetInput(context, node, kDataInput0Tensor);
|
||||||
|
TF_LITE_ENSURE(context, input != nullptr);
|
||||||
|
TfLiteTensor* output = tflite::GetOutput(context, node, kOutputTensor);
|
||||||
|
TF_LITE_ENSURE(context, output != nullptr);
|
||||||
|
|
||||||
|
TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(input), 4);
|
||||||
|
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
|
||||||
|
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
|
||||||
|
|
||||||
|
TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
|
||||||
|
output_size->data[0] = input->dims->data[0];
|
||||||
|
output_size->data[1] = input->dims->data[1];
|
||||||
|
output_size->data[2] = input->dims->data[2];
|
||||||
|
output_size->data[3] = input->dims->data[3];
|
||||||
|
|
||||||
|
return context->ResizeTensor(context, output, output_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
tflite::gpu::TransformLandmarksAttributes op_params;
|
||||||
|
tflite::gpu::BHWC output_shape;
|
||||||
|
auto status = tflite::gpu::ParseTransformLandmarksV1Attributes(
|
||||||
|
node->custom_initial_data, node->custom_initial_data_size, &op_params,
|
||||||
|
&output_shape);
|
||||||
|
if (!status.ok()) {
|
||||||
|
context->ReportError(context, status.message().data());
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
if (op_params.dimensions != 3 && op_params.dimensions != 2) {
|
||||||
|
context->ReportError(context, "Incorrect dimensions size: %d",
|
||||||
|
op_params.dimensions);
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
if (op_params.scale == 0) {
|
||||||
|
context->ReportError(context, "Incorrect scale value: %d", op_params.scale);
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
const TfLiteTensor* input0 =
|
||||||
|
tflite::GetInput(context, node, kDataInput0Tensor);
|
||||||
|
TF_LITE_ENSURE(context, input0 != nullptr);
|
||||||
|
const TfLiteTensor* input1 =
|
||||||
|
tflite::GetInput(context, node, kDataInput1Tensor);
|
||||||
|
TF_LITE_ENSURE(context, input1 != nullptr);
|
||||||
|
TfLiteTensor* output = tflite::GetOutput(context, node, kOutputTensor);
|
||||||
|
TF_LITE_ENSURE(context, output != nullptr);
|
||||||
|
|
||||||
|
TransformLandmarks(
|
||||||
|
op_params, tflite::GetTensorShape(input0),
|
||||||
|
tflite::GetTensorData<float>(input0), tflite::GetTensorShape(input1),
|
||||||
|
tflite::GetTensorData<float>(input1), tflite::GetTensorShape(output),
|
||||||
|
tflite::GetTensorData<float>(output));
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace v1
|
||||||
|
|
||||||
|
namespace v2 {
|
||||||
|
|
||||||
|
inline void TransformLandmarksV2(
|
||||||
|
const tflite::gpu::TransformLandmarksAttributes& params,
|
||||||
|
const tflite::RuntimeShape& input0_shape, const float* landmarks,
|
||||||
|
const float* transform_matrix, // transformation matrix
|
||||||
|
const tflite::RuntimeShape& output_shape, float* output_data) {
|
||||||
|
TFLITE_CHECK_EQ(input0_shape.DimensionsCount(), 3);
|
||||||
|
TFLITE_CHECK_EQ(output_shape.DimensionsCount(), 3);
|
||||||
|
const int output_width = output_shape.Dims(1);
|
||||||
|
TFLITE_CHECK_EQ(input0_shape.Dims(2) % params.dimensions, 0);
|
||||||
|
|
||||||
|
tflite::RuntimeShape input_shape_with_batch{/*batch=*/1, input0_shape.Dims(0),
|
||||||
|
input0_shape.Dims(1),
|
||||||
|
input0_shape.Dims(2)};
|
||||||
|
tflite::RuntimeShape output_shape_with_batch{
|
||||||
|
/*batch=*/1, output_shape.Dims(0), output_shape.Dims(1),
|
||||||
|
output_shape.Dims(2)};
|
||||||
|
|
||||||
|
// Read first two rows of transformation matrix
|
||||||
|
tflite::gpu::float4 x_transform(transform_matrix[0], transform_matrix[1],
|
||||||
|
transform_matrix[2], transform_matrix[3]);
|
||||||
|
tflite::gpu::float4 y_transform(transform_matrix[4], transform_matrix[5],
|
||||||
|
transform_matrix[6], transform_matrix[7]);
|
||||||
|
|
||||||
|
for (int landmark = 0; landmark < output_width; ++landmark) {
|
||||||
|
const int offset = Offset(input_shape_with_batch, 0, 0, landmark, 0);
|
||||||
|
|
||||||
|
if (params.dimensions == 2) {
|
||||||
|
tflite::gpu::float4 lv(landmarks[offset], landmarks[offset + 1],
|
||||||
|
static_cast<float>(0.0), static_cast<float>(1.0));
|
||||||
|
tflite::gpu::float2 transformed(DotProduct(x_transform, lv),
|
||||||
|
DotProduct(y_transform, lv));
|
||||||
|
output_data[offset] = transformed.x;
|
||||||
|
output_data[offset + 1] = transformed.y;
|
||||||
|
}
|
||||||
|
if (params.dimensions == 3) {
|
||||||
|
tflite::gpu::float4 lv(landmarks[offset], landmarks[offset + 1],
|
||||||
|
static_cast<float>(0.0), static_cast<float>(1.0));
|
||||||
|
tflite::gpu::float3 transformed(DotProduct(x_transform, lv),
|
||||||
|
DotProduct(y_transform, lv), lv.z);
|
||||||
|
output_data[offset] = transformed.x;
|
||||||
|
output_data[offset + 1] = transformed.y;
|
||||||
|
output_data[offset + 2] = landmarks[offset + 2];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
TF_LITE_ENSURE_EQ(context, tflite::NumInputs(node), 2);
|
||||||
|
TF_LITE_ENSURE_EQ(context, tflite::NumOutputs(node), 1);
|
||||||
|
const TfLiteTensor* input =
|
||||||
|
tflite::GetInput(context, node, kDataInput0Tensor);
|
||||||
|
TF_LITE_ENSURE(context, input != nullptr);
|
||||||
|
TfLiteTensor* output = tflite::GetOutput(context, node, kOutputTensor);
|
||||||
|
TF_LITE_ENSURE(context, output != nullptr);
|
||||||
|
|
||||||
|
TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(input), 3);
|
||||||
|
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
|
||||||
|
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
|
||||||
|
|
||||||
|
TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
|
||||||
|
output_size->data[0] = input->dims->data[0];
|
||||||
|
output_size->data[1] = input->dims->data[1];
|
||||||
|
output_size->data[2] = input->dims->data[2];
|
||||||
|
|
||||||
|
return context->ResizeTensor(context, output, output_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
tflite::gpu::TransformLandmarksAttributes op_params;
|
||||||
|
|
||||||
|
TfLiteTensor* output = tflite::GetOutput(context, node, kOutputTensor);
|
||||||
|
TF_LITE_ENSURE(context, output != nullptr);
|
||||||
|
tflite::RuntimeShape runtime_output_shape = tflite::GetTensorShape(output);
|
||||||
|
tflite::gpu::BHWC output_shape(1, runtime_output_shape.Dims(0),
|
||||||
|
runtime_output_shape.Dims(1),
|
||||||
|
runtime_output_shape.Dims(2));
|
||||||
|
auto status = tflite::gpu::ParseTransformLandmarksV2Attributes(
|
||||||
|
node->custom_initial_data, node->custom_initial_data_size, &op_params,
|
||||||
|
&output_shape);
|
||||||
|
if (!status.ok()) {
|
||||||
|
context->ReportError(context, status.message().data());
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
if (op_params.dimensions != 3 && op_params.dimensions != 2) {
|
||||||
|
context->ReportError(context, "Incorrect dimensions size: %d",
|
||||||
|
op_params.dimensions);
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
const TfLiteTensor* input0 =
|
||||||
|
tflite::GetInput(context, node, kDataInput0Tensor);
|
||||||
|
TF_LITE_ENSURE(context, input0 != nullptr);
|
||||||
|
const TfLiteTensor* input1 =
|
||||||
|
tflite::GetInput(context, node, kDataInput1Tensor);
|
||||||
|
TF_LITE_ENSURE(context, input1 != nullptr);
|
||||||
|
|
||||||
|
TransformLandmarksV2(op_params, tflite::GetTensorShape(input0),
|
||||||
|
tflite::GetTensorData<float>(input0),
|
||||||
|
tflite::GetTensorData<float>(input1),
|
||||||
|
tflite::GetTensorShape(output),
|
||||||
|
tflite::GetTensorData<float>(output));
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace v2
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
TfLiteRegistration* RegisterTransformLandmarksV1() {
|
||||||
|
static TfLiteRegistration reg = {
|
||||||
|
/*.init=*/nullptr,
|
||||||
|
/*.free=*/nullptr,
|
||||||
|
/*.prepare=*/v1::Prepare,
|
||||||
|
/*.invoke=*/v1::Eval,
|
||||||
|
/*.profiling_string=*/nullptr,
|
||||||
|
/*.builtin_code=*/tflite::BuiltinOperator_CUSTOM,
|
||||||
|
/*.custom_name=*/"TransformLandmarks",
|
||||||
|
/*.version=*/1,
|
||||||
|
};
|
||||||
|
return ®
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteRegistration* RegisterTransformLandmarksV2() {
|
||||||
|
static TfLiteRegistration reg = {
|
||||||
|
/*.init=*/nullptr,
|
||||||
|
/*.free=*/nullptr,
|
||||||
|
/*.prepare=*/v2::Prepare,
|
||||||
|
/*.invoke=*/v2::Eval,
|
||||||
|
/*.profiling_string=*/nullptr,
|
||||||
|
/*.builtin_code=*/tflite::BuiltinOperator_CUSTOM,
|
||||||
|
/*.custom_name=*/"TransformLandmarks",
|
||||||
|
/*.version=*/2,
|
||||||
|
};
|
||||||
|
return ®
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tflite_operations
|
||||||
|
} // namespace mediapipe
|
30
mediapipe/util/tflite/operations/transform_landmarks.h
Normal file
30
mediapipe/util/tflite/operations/transform_landmarks.h
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
// Copyright 2021 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_UTIL_TFLITE_OPERATIONS_TRANSFORM_LANDMARKS_H_
|
||||||
|
#define MEDIAPIPE_UTIL_TFLITE_OPERATIONS_TRANSFORM_LANDMARKS_H_
|
||||||
|
|
||||||
|
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace tflite_operations {
|
||||||
|
|
||||||
|
TfLiteRegistration* RegisterTransformLandmarksV1();
|
||||||
|
|
||||||
|
TfLiteRegistration* RegisterTransformLandmarksV2();
|
||||||
|
|
||||||
|
} // namespace tflite_operations
|
||||||
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_UTIL_TFLITE_OPERATIONS_TRANSFORM_LANDMARKS_H_
|
332
mediapipe/util/tflite/operations/transform_tensor_bilinear.cc
Normal file
332
mediapipe/util/tflite/operations/transform_tensor_bilinear.cc
Normal file
|
@ -0,0 +1,332 @@
|
||||||
|
// Copyright 2021 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "mediapipe/util/tflite/operations/transform_tensor_bilinear.h"
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/mediapipe/transform_tensor_bilinear.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/common.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||||
|
#include "tensorflow/lite/kernels/padding.h"
|
||||||
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace tflite_operations {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
constexpr int kDataInput0Tensor = 0;
|
||||||
|
constexpr int kDataInput1Tensor = 1;
|
||||||
|
constexpr int kOutputTensor = 0;
|
||||||
|
|
||||||
|
float DotProduct(const tflite::gpu::float4& l, const tflite::gpu::float4& r) {
|
||||||
|
return l.x * r.x + l.y * r.y + l.z * r.z + l.w * r.w;
|
||||||
|
}
|
||||||
|
namespace v1 {
|
||||||
|
|
||||||
|
inline void TransformTensor(
|
||||||
|
const tflite::gpu::TransformTensorBilinearAttributes& params,
|
||||||
|
const tflite::RuntimeShape& input0_shape,
|
||||||
|
const float* input_data_0, // data
|
||||||
|
const tflite::RuntimeShape& input1_shape,
|
||||||
|
const float* input_data_1, // transformation matrix
|
||||||
|
const tflite::RuntimeShape& output_shape, float* output_data) {
|
||||||
|
TFLITE_CHECK_EQ(input0_shape.DimensionsCount(), 4);
|
||||||
|
TFLITE_CHECK_EQ(output_shape.DimensionsCount(), 4);
|
||||||
|
const int output_height = output_shape.Dims(1);
|
||||||
|
const int output_width = output_shape.Dims(2);
|
||||||
|
const int output_channels = output_shape.Dims(3);
|
||||||
|
|
||||||
|
const int input_height = input0_shape.Dims(1);
|
||||||
|
const int input_width = input0_shape.Dims(2);
|
||||||
|
const int input_channels = input0_shape.Dims(3);
|
||||||
|
|
||||||
|
tflite::RuntimeShape input_shape_with_batch{/*batch=*/1, input_height,
|
||||||
|
input_width, input_channels};
|
||||||
|
tflite::RuntimeShape output_shape_with_batch{/*batch=*/1, output_height,
|
||||||
|
output_width, output_channels};
|
||||||
|
|
||||||
|
// Read first two rows of transformation matrix
|
||||||
|
tflite::gpu::float4 x_transform(input_data_1[0], input_data_1[1],
|
||||||
|
input_data_1[2], input_data_1[3]);
|
||||||
|
tflite::gpu::float4 y_transform(input_data_1[4], input_data_1[5],
|
||||||
|
input_data_1[6], input_data_1[7]);
|
||||||
|
|
||||||
|
for (int out_y = 0; out_y < output_height; ++out_y) {
|
||||||
|
for (int out_x = 0; out_x < output_width; ++out_x) {
|
||||||
|
tflite::gpu::float4 coord(
|
||||||
|
static_cast<float>(out_x), static_cast<float>(out_y),
|
||||||
|
static_cast<float>(0.0), static_cast<float>(1.0));
|
||||||
|
|
||||||
|
// Transformed coordinates.
|
||||||
|
tflite::gpu::float2 tc(DotProduct(x_transform, coord),
|
||||||
|
DotProduct(y_transform, coord));
|
||||||
|
|
||||||
|
bool out_of_bound = tc.x < 0.0 || tc.x > input_width - 1 || tc.y < 0.0 ||
|
||||||
|
tc.y > input_height - 1;
|
||||||
|
|
||||||
|
for (int out_z = 0; out_z < output_channels; ++out_z) {
|
||||||
|
float result = 0;
|
||||||
|
if (!out_of_bound) {
|
||||||
|
// Corners position:
|
||||||
|
// q_11 --- q_21
|
||||||
|
// ---- ----
|
||||||
|
// q_12 --- q_22
|
||||||
|
|
||||||
|
auto ReadValue = [&](int h, int w) -> float {
|
||||||
|
return h < 0 || w < 0 || h >= input_height || w >= input_width
|
||||||
|
? 0
|
||||||
|
: input_data_0[Offset(input_shape_with_batch, 0, h, w,
|
||||||
|
out_z)];
|
||||||
|
};
|
||||||
|
|
||||||
|
float q_11 = ReadValue(floor(tc.y), floor(tc.x));
|
||||||
|
float q_21 = ReadValue(floor(tc.y), floor(tc.x) + 1);
|
||||||
|
float q_12 = ReadValue(floor(tc.y) + 1, floor(tc.x));
|
||||||
|
float q_22 = ReadValue(floor(tc.y) + 1, floor(tc.x) + 1);
|
||||||
|
|
||||||
|
float right_contrib = tc.x - floor(tc.x);
|
||||||
|
float lower_contrib = tc.y - floor(tc.y);
|
||||||
|
|
||||||
|
float upper = (1.0 - right_contrib) * q_11 + right_contrib * q_21;
|
||||||
|
float lower = (1.0 - right_contrib) * q_12 + right_contrib * q_22;
|
||||||
|
|
||||||
|
result = lower_contrib * lower + (1.0 - lower_contrib) * upper;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int out_offset =
|
||||||
|
Offset(output_shape_with_batch, 0, out_y, out_x, out_z);
|
||||||
|
|
||||||
|
output_data[out_offset] = result;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
TF_LITE_ENSURE_EQ(context, tflite::NumInputs(node), 2);
|
||||||
|
TF_LITE_ENSURE_EQ(context, tflite::NumOutputs(node), 1);
|
||||||
|
const TfLiteTensor* input =
|
||||||
|
tflite::GetInput(context, node, kDataInput0Tensor);
|
||||||
|
TF_LITE_ENSURE(context, input != nullptr);
|
||||||
|
TfLiteTensor* output = tflite::GetOutput(context, node, kOutputTensor);
|
||||||
|
TF_LITE_ENSURE(context, output != nullptr);
|
||||||
|
|
||||||
|
TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(input), 4);
|
||||||
|
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
|
||||||
|
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
|
||||||
|
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
tflite::gpu::TransformTensorBilinearAttributes op_params;
|
||||||
|
tflite::gpu::BHWC output_shape;
|
||||||
|
auto status = tflite::gpu::ParseTransformTensorBilinearV1Attributes(
|
||||||
|
node->custom_initial_data, node->custom_initial_data_size, &op_params,
|
||||||
|
&output_shape);
|
||||||
|
if (!status.ok()) {
|
||||||
|
context->ReportError(context, status.message().data());
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
const TfLiteTensor* input0 =
|
||||||
|
tflite::GetInput(context, node, kDataInput0Tensor);
|
||||||
|
TF_LITE_ENSURE(context, input0 != nullptr);
|
||||||
|
const TfLiteTensor* input1 =
|
||||||
|
tflite::GetInput(context, node, kDataInput1Tensor);
|
||||||
|
TF_LITE_ENSURE(context, input1 != nullptr);
|
||||||
|
TfLiteTensor* output = tflite::GetOutput(context, node, kOutputTensor);
|
||||||
|
TF_LITE_ENSURE(context, output != nullptr);
|
||||||
|
|
||||||
|
TransformTensor(
|
||||||
|
op_params, tflite::GetTensorShape(input0),
|
||||||
|
tflite::GetTensorData<float>(input0), tflite::GetTensorShape(input1),
|
||||||
|
tflite::GetTensorData<float>(input1), tflite::GetTensorShape(output),
|
||||||
|
tflite::GetTensorData<float>(output));
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
} // namespace v1
|
||||||
|
|
||||||
|
namespace v2 {
|
||||||
|
|
||||||
|
inline void TransformTensorBilinearV2(
|
||||||
|
const tflite::gpu::TransformTensorBilinearAttributes& params,
|
||||||
|
const tflite::RuntimeShape& input0_shape,
|
||||||
|
const float* input_data_0, // data
|
||||||
|
const tflite::RuntimeShape& input1_shape,
|
||||||
|
const float* input_data_1, // transformation matrix
|
||||||
|
const tflite::RuntimeShape& output_shape, float* output_data) {
|
||||||
|
TFLITE_CHECK_EQ(input0_shape.DimensionsCount(), 4);
|
||||||
|
TFLITE_CHECK_EQ(output_shape.DimensionsCount(), 4);
|
||||||
|
const int output_height = output_shape.Dims(1);
|
||||||
|
const int output_width = output_shape.Dims(2);
|
||||||
|
const int output_channels = output_shape.Dims(3);
|
||||||
|
|
||||||
|
const int input_height = input0_shape.Dims(1);
|
||||||
|
const int input_width = input0_shape.Dims(2);
|
||||||
|
const int input_channels = input0_shape.Dims(3);
|
||||||
|
|
||||||
|
tflite::RuntimeShape input_shape_with_batch{/*batch=*/1, input_height,
|
||||||
|
input_width, input_channels};
|
||||||
|
tflite::RuntimeShape output_shape_with_batch{/*batch=*/1, output_height,
|
||||||
|
output_width, output_channels};
|
||||||
|
|
||||||
|
// Read first two rows of transformation matrix
|
||||||
|
tflite::gpu::float4 x_transform(input_data_1[0], input_data_1[1],
|
||||||
|
input_data_1[2], input_data_1[3]);
|
||||||
|
tflite::gpu::float4 y_transform(input_data_1[4], input_data_1[5],
|
||||||
|
input_data_1[6], input_data_1[7]);
|
||||||
|
|
||||||
|
// Align corners correction: T -> S * ( T * A ), where T is a
|
||||||
|
// transformation matrix, and subtruction and addition matrices are:
|
||||||
|
// S A
|
||||||
|
// 1 0 0 -0.5 1 0 0 0.5
|
||||||
|
// 0 1 0 -0.5 0 1 0 0.5
|
||||||
|
// 0 0 1 0 0 0 1 0
|
||||||
|
// 0 0 0 1 0 0 0 1
|
||||||
|
// Transformation matrix column 3 and rows 3, 4 are identity, which makes
|
||||||
|
// the final formula pretty simple and easy to get if doing a manual
|
||||||
|
// multiuplication.
|
||||||
|
x_transform[3] += x_transform[0] * 0.5 + x_transform[1] * 0.5 - 0.5;
|
||||||
|
y_transform[3] += y_transform[0] * 0.5 + y_transform[1] * 0.5 - 0.5;
|
||||||
|
|
||||||
|
for (int out_y = 0; out_y < output_height; ++out_y) {
|
||||||
|
for (int out_x = 0; out_x < output_width; ++out_x) {
|
||||||
|
tflite::gpu::float4 coord(
|
||||||
|
static_cast<float>(out_x), static_cast<float>(out_y),
|
||||||
|
static_cast<float>(0.0), static_cast<float>(1.0));
|
||||||
|
|
||||||
|
// Transformed coordinates.
|
||||||
|
tflite::gpu::float2 tc(DotProduct(x_transform, coord),
|
||||||
|
DotProduct(y_transform, coord));
|
||||||
|
|
||||||
|
bool out_of_bound = tc.x < 0.0 || tc.x > input_width - 1 || tc.y < 0.0 ||
|
||||||
|
tc.y > input_height - 1;
|
||||||
|
|
||||||
|
for (int out_z = 0; out_z < output_channels; ++out_z) {
|
||||||
|
float result = 0;
|
||||||
|
if (!out_of_bound) {
|
||||||
|
// Corners position:
|
||||||
|
// q_11 --- q_21
|
||||||
|
// ---- ----
|
||||||
|
// q_12 --- q_22
|
||||||
|
|
||||||
|
auto ReadValue = [&](int h, int w) -> float {
|
||||||
|
return h < 0 || w < 0 || h >= input_height || w >= input_width
|
||||||
|
? 0
|
||||||
|
: input_data_0[Offset(input_shape_with_batch, 0, h, w,
|
||||||
|
out_z)];
|
||||||
|
};
|
||||||
|
|
||||||
|
float q_11 = ReadValue(floor(tc.y), floor(tc.x));
|
||||||
|
float q_21 = ReadValue(floor(tc.y), floor(tc.x) + 1);
|
||||||
|
float q_12 = ReadValue(floor(tc.y) + 1, floor(tc.x));
|
||||||
|
float q_22 = ReadValue(floor(tc.y) + 1, floor(tc.x) + 1);
|
||||||
|
|
||||||
|
float right_contrib = tc.x - floor(tc.x);
|
||||||
|
float lower_contrib = tc.y - floor(tc.y);
|
||||||
|
|
||||||
|
float upper = (1.0 - right_contrib) * q_11 + right_contrib * q_21;
|
||||||
|
float lower = (1.0 - right_contrib) * q_12 + right_contrib * q_22;
|
||||||
|
|
||||||
|
result = lower_contrib * lower + (1.0 - lower_contrib) * upper;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int out_offset =
|
||||||
|
Offset(output_shape_with_batch, 0, out_y, out_x, out_z);
|
||||||
|
|
||||||
|
output_data[out_offset] = result;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
TF_LITE_ENSURE_EQ(context, tflite::NumInputs(node), 2);
|
||||||
|
TF_LITE_ENSURE_EQ(context, tflite::NumOutputs(node), 1);
|
||||||
|
const TfLiteTensor* input =
|
||||||
|
tflite::GetInput(context, node, kDataInput0Tensor);
|
||||||
|
TF_LITE_ENSURE(context, input != nullptr);
|
||||||
|
TfLiteTensor* output = tflite::GetOutput(context, node, kOutputTensor);
|
||||||
|
TF_LITE_ENSURE(context, output != nullptr);
|
||||||
|
|
||||||
|
TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(input), 4);
|
||||||
|
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
|
||||||
|
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
|
||||||
|
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
tflite::gpu::TransformTensorBilinearAttributes op_params;
|
||||||
|
tflite::gpu::BHWC output_shape;
|
||||||
|
auto status = tflite::gpu::ParseTransformTensorBilinearV2Attributes(
|
||||||
|
node->custom_initial_data, node->custom_initial_data_size, &op_params,
|
||||||
|
&output_shape);
|
||||||
|
if (!status.ok()) {
|
||||||
|
context->ReportError(context, status.message().data());
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
const TfLiteTensor* input0 =
|
||||||
|
tflite::GetInput(context, node, kDataInput0Tensor);
|
||||||
|
TF_LITE_ENSURE(context, input0 != nullptr);
|
||||||
|
const TfLiteTensor* input1 =
|
||||||
|
tflite::GetInput(context, node, kDataInput1Tensor);
|
||||||
|
TF_LITE_ENSURE(context, input1 != nullptr);
|
||||||
|
TfLiteTensor* output = tflite::GetOutput(context, node, kOutputTensor);
|
||||||
|
TF_LITE_ENSURE(context, output != nullptr);
|
||||||
|
|
||||||
|
TransformTensorBilinearV2(
|
||||||
|
op_params, tflite::GetTensorShape(input0),
|
||||||
|
tflite::GetTensorData<float>(input0), tflite::GetTensorShape(input1),
|
||||||
|
tflite::GetTensorData<float>(input1), tflite::GetTensorShape(output),
|
||||||
|
tflite::GetTensorData<float>(output));
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
} // namespace v2
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
TfLiteRegistration* RegisterTransformTensorBilinearV1() {
|
||||||
|
static TfLiteRegistration reg = {
|
||||||
|
/*.init=*/nullptr,
|
||||||
|
/*.free=*/nullptr,
|
||||||
|
/*.prepare=*/v1::Prepare,
|
||||||
|
/*.invoke=*/v1::Eval,
|
||||||
|
/*.profiling_string=*/nullptr,
|
||||||
|
/*.builtin_code=*/tflite::BuiltinOperator_CUSTOM,
|
||||||
|
/*.custom_name=*/"TransformTensor",
|
||||||
|
/*.version=*/1,
|
||||||
|
};
|
||||||
|
return ®
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteRegistration* RegisterTransformTensorBilinearV2() {
|
||||||
|
static TfLiteRegistration reg = {
|
||||||
|
/*.init=*/nullptr,
|
||||||
|
/*.free=*/nullptr,
|
||||||
|
/*.prepare=*/v2::Prepare,
|
||||||
|
/*.invoke=*/v2::Eval,
|
||||||
|
/*.profiling_string=*/nullptr,
|
||||||
|
/*.builtin_code=*/tflite::BuiltinOperator_CUSTOM,
|
||||||
|
/*.custom_name=*/"TransformTensorBilinear",
|
||||||
|
/*.version=*/2,
|
||||||
|
};
|
||||||
|
return ®
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tflite_operations
|
||||||
|
} // namespace mediapipe
|
30
mediapipe/util/tflite/operations/transform_tensor_bilinear.h
Normal file
30
mediapipe/util/tflite/operations/transform_tensor_bilinear.h
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
// Copyright 2021 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_UTIL_TFLITE_OPERATIONS_TRANSFORM_TENSOR_BILINEAR_H_
|
||||||
|
#define MEDIAPIPE_UTIL_TFLITE_OPERATIONS_TRANSFORM_TENSOR_BILINEAR_H_
|
||||||
|
|
||||||
|
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace tflite_operations {
|
||||||
|
|
||||||
|
TfLiteRegistration* RegisterTransformTensorBilinearV1();
|
||||||
|
|
||||||
|
TfLiteRegistration* RegisterTransformTensorBilinearV2();
|
||||||
|
|
||||||
|
} // namespace tflite_operations
|
||||||
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_UTIL_TFLITE_OPERATIONS_TRANSFORM_TENSOR_BILINEAR_H_
|
Loading…
Reference in New Issue
Block a user