Tensor: Make tensor not requiring "-x objective-c++" option.
In this case tensor.h is compiled differently for C++ and Objective-C++ that violates ODR (once definition rule). Tensor has no virtual methods conditionally compiled but some Metal-related data members. Instead, unique_ptr to MtlResources that is declared as forward structure is unconditionally defined in the tensor class. MtlResources is defined differently in cc-file only that compiled just once per project so no ODR violation is here. PiperOrigin-RevId: 504029286
This commit is contained in:
parent
921b6a6bef
commit
1124569c29
|
@ -53,14 +53,6 @@ mediapipe_proto_library(
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "audio_to_tensor_calculator",
|
name = "audio_to_tensor_calculator",
|
||||||
srcs = ["audio_to_tensor_calculator.cc"],
|
srcs = ["audio_to_tensor_calculator.cc"],
|
||||||
copts = select({
|
|
||||||
# b/215212850
|
|
||||||
"//mediapipe:apple": [
|
|
||||||
"-x objective-c++",
|
|
||||||
"-fobjc-arc",
|
|
||||||
],
|
|
||||||
"//conditions:default": [],
|
|
||||||
}),
|
|
||||||
deps = [
|
deps = [
|
||||||
":audio_to_tensor_calculator_cc_proto",
|
":audio_to_tensor_calculator_cc_proto",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
@ -161,14 +153,6 @@ mediapipe_proto_library(
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "feedback_tensors_calculator",
|
name = "feedback_tensors_calculator",
|
||||||
srcs = ["feedback_tensors_calculator.cc"],
|
srcs = ["feedback_tensors_calculator.cc"],
|
||||||
copts = select({
|
|
||||||
# b/215212850
|
|
||||||
"//mediapipe:apple": [
|
|
||||||
"-x objective-c++",
|
|
||||||
"-fobjc-arc",
|
|
||||||
],
|
|
||||||
"//conditions:default": [],
|
|
||||||
}),
|
|
||||||
deps = [
|
deps = [
|
||||||
":feedback_tensors_calculator_cc_proto",
|
":feedback_tensors_calculator_cc_proto",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
@ -207,14 +191,6 @@ mediapipe_proto_library(
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "bert_preprocessor_calculator",
|
name = "bert_preprocessor_calculator",
|
||||||
srcs = ["bert_preprocessor_calculator.cc"],
|
srcs = ["bert_preprocessor_calculator.cc"],
|
||||||
copts = select({
|
|
||||||
# TODO: fix tensor.h not to require this, if possible
|
|
||||||
"//mediapipe:apple": [
|
|
||||||
"-x objective-c++",
|
|
||||||
"-fobjc-arc", # enable reference-counting
|
|
||||||
],
|
|
||||||
"//conditions:default": [],
|
|
||||||
}),
|
|
||||||
deps = [
|
deps = [
|
||||||
":bert_preprocessor_calculator_cc_proto",
|
":bert_preprocessor_calculator_cc_proto",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
@ -267,14 +243,6 @@ mediapipe_proto_library(
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "regex_preprocessor_calculator",
|
name = "regex_preprocessor_calculator",
|
||||||
srcs = ["regex_preprocessor_calculator.cc"],
|
srcs = ["regex_preprocessor_calculator.cc"],
|
||||||
copts = select({
|
|
||||||
# TODO: fix tensor.h not to require this, if possible
|
|
||||||
"//mediapipe:apple": [
|
|
||||||
"-x objective-c++",
|
|
||||||
"-fobjc-arc", # enable reference-counting
|
|
||||||
],
|
|
||||||
"//conditions:default": [],
|
|
||||||
}),
|
|
||||||
deps = [
|
deps = [
|
||||||
":regex_preprocessor_calculator_cc_proto",
|
":regex_preprocessor_calculator_cc_proto",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
@ -316,14 +284,6 @@ cc_test(
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "text_to_tensor_calculator",
|
name = "text_to_tensor_calculator",
|
||||||
srcs = ["text_to_tensor_calculator.cc"],
|
srcs = ["text_to_tensor_calculator.cc"],
|
||||||
copts = select({
|
|
||||||
# TODO: fix tensor.h not to require this, if possible
|
|
||||||
"//mediapipe:apple": [
|
|
||||||
"-x objective-c++",
|
|
||||||
"-fobjc-arc", # enable reference-counting
|
|
||||||
],
|
|
||||||
"//conditions:default": [],
|
|
||||||
}),
|
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/framework:calculator_context",
|
"//mediapipe/framework:calculator_context",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
@ -414,14 +374,6 @@ cc_library(
|
||||||
name = "inference_calculator_interface",
|
name = "inference_calculator_interface",
|
||||||
srcs = ["inference_calculator.cc"],
|
srcs = ["inference_calculator.cc"],
|
||||||
hdrs = ["inference_calculator.h"],
|
hdrs = ["inference_calculator.h"],
|
||||||
copts = select({
|
|
||||||
# TODO: fix tensor.h not to require this, if possible
|
|
||||||
"//mediapipe:apple": [
|
|
||||||
"-x objective-c++",
|
|
||||||
"-fobjc-arc", # enable reference-counting
|
|
||||||
],
|
|
||||||
"//conditions:default": [],
|
|
||||||
}),
|
|
||||||
deps = [
|
deps = [
|
||||||
":inference_calculator_cc_proto",
|
":inference_calculator_cc_proto",
|
||||||
":inference_calculator_options_lib",
|
":inference_calculator_options_lib",
|
||||||
|
@ -495,6 +447,7 @@ cc_library(
|
||||||
tags = ["ios"],
|
tags = ["ios"],
|
||||||
deps = [
|
deps = [
|
||||||
"inference_calculator_interface",
|
"inference_calculator_interface",
|
||||||
|
"//mediapipe/framework/formats:tensor",
|
||||||
"//mediapipe/gpu:MPPMetalHelper",
|
"//mediapipe/gpu:MPPMetalHelper",
|
||||||
"//mediapipe/gpu:MPPMetalUtil",
|
"//mediapipe/gpu:MPPMetalUtil",
|
||||||
"//mediapipe/gpu:gpu_buffer",
|
"//mediapipe/gpu:gpu_buffer",
|
||||||
|
@ -513,14 +466,6 @@ cc_library(
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "inference_runner",
|
name = "inference_runner",
|
||||||
hdrs = ["inference_runner.h"],
|
hdrs = ["inference_runner.h"],
|
||||||
copts = select({
|
|
||||||
# TODO: fix tensor.h not to require this, if possible
|
|
||||||
"//mediapipe:apple": [
|
|
||||||
"-x objective-c++",
|
|
||||||
"-fobjc-arc", # enable reference-counting
|
|
||||||
],
|
|
||||||
"//conditions:default": [],
|
|
||||||
}),
|
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/framework:calculator_context",
|
"//mediapipe/framework:calculator_context",
|
||||||
"//mediapipe/framework/formats:tensor",
|
"//mediapipe/framework/formats:tensor",
|
||||||
|
@ -532,14 +477,6 @@ cc_library(
|
||||||
name = "inference_interpreter_delegate_runner",
|
name = "inference_interpreter_delegate_runner",
|
||||||
srcs = ["inference_interpreter_delegate_runner.cc"],
|
srcs = ["inference_interpreter_delegate_runner.cc"],
|
||||||
hdrs = ["inference_interpreter_delegate_runner.h"],
|
hdrs = ["inference_interpreter_delegate_runner.h"],
|
||||||
copts = select({
|
|
||||||
# TODO: fix tensor.h not to require this, if possible
|
|
||||||
"//mediapipe:apple": [
|
|
||||||
"-x objective-c++",
|
|
||||||
"-fobjc-arc", # enable reference-counting
|
|
||||||
],
|
|
||||||
"//conditions:default": [],
|
|
||||||
}),
|
|
||||||
deps = [
|
deps = [
|
||||||
":inference_runner",
|
":inference_runner",
|
||||||
"//mediapipe/framework:mediapipe_profiling",
|
"//mediapipe/framework:mediapipe_profiling",
|
||||||
|
@ -561,14 +498,6 @@ cc_library(
|
||||||
srcs = [
|
srcs = [
|
||||||
"inference_calculator_cpu.cc",
|
"inference_calculator_cpu.cc",
|
||||||
],
|
],
|
||||||
copts = select({
|
|
||||||
# TODO: fix tensor.h not to require this, if possible
|
|
||||||
"//mediapipe:apple": [
|
|
||||||
"-x objective-c++",
|
|
||||||
"-fobjc-arc", # enable reference-counting
|
|
||||||
],
|
|
||||||
"//conditions:default": [],
|
|
||||||
}),
|
|
||||||
deps = [
|
deps = [
|
||||||
":inference_calculator_interface",
|
":inference_calculator_interface",
|
||||||
":inference_calculator_utils",
|
":inference_calculator_utils",
|
||||||
|
@ -607,14 +536,6 @@ cc_library(
|
||||||
srcs = [
|
srcs = [
|
||||||
"inference_calculator_xnnpack.cc",
|
"inference_calculator_xnnpack.cc",
|
||||||
],
|
],
|
||||||
copts = select({
|
|
||||||
# TODO: fix tensor.h not to require this, if possible
|
|
||||||
"//mediapipe:apple": [
|
|
||||||
"-x objective-c++",
|
|
||||||
"-fobjc-arc", # enable reference-counting
|
|
||||||
],
|
|
||||||
"//conditions:default": [],
|
|
||||||
}),
|
|
||||||
deps = [
|
deps = [
|
||||||
":inference_calculator_interface",
|
":inference_calculator_interface",
|
||||||
":inference_calculator_utils",
|
":inference_calculator_utils",
|
||||||
|
|
|
@ -36,6 +36,10 @@
|
||||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
||||||
|
|
||||||
|
#if MEDIAPIPE_METAL_ENABLED
|
||||||
|
#include "mediapipe/framework/formats/tensor_mtl_buffer_view.h"
|
||||||
|
#endif // MEDIAPIPE_METAL_ENABLED
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -376,7 +380,7 @@ class MetalProcessor : public ImageToTensorConverter {
|
||||||
|
|
||||||
id<MTLCommandBuffer> command_buffer = [metal_helper_ commandBuffer];
|
id<MTLCommandBuffer> command_buffer = [metal_helper_ commandBuffer];
|
||||||
const auto& buffer_view =
|
const auto& buffer_view =
|
||||||
output_tensor.GetMtlBufferWriteView(command_buffer);
|
MtlBufferView::GetWriteView(output_tensor, command_buffer);
|
||||||
MP_RETURN_IF_ERROR(extractor_->Execute(
|
MP_RETURN_IF_ERROR(extractor_->Execute(
|
||||||
texture, roi,
|
texture, roi,
|
||||||
/*flip_horizontaly=*/false, transform.scale, transform.offset,
|
/*flip_horizontaly=*/false, transform.scale, transform.offset,
|
||||||
|
|
|
@ -24,6 +24,8 @@
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/strings/str_format.h"
|
#include "absl/strings/str_format.h"
|
||||||
#include "mediapipe/calculators/tensor/inference_calculator.h"
|
#include "mediapipe/calculators/tensor/inference_calculator.h"
|
||||||
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
|
#include "mediapipe/framework/formats/tensor_mtl_buffer_view.h"
|
||||||
#import "mediapipe/gpu/MPPMetalHelper.h"
|
#import "mediapipe/gpu/MPPMetalHelper.h"
|
||||||
#include "mediapipe/gpu/MPPMetalUtil.h"
|
#include "mediapipe/gpu/MPPMetalUtil.h"
|
||||||
#include "mediapipe/gpu/gpu_buffer.h"
|
#include "mediapipe/gpu/gpu_buffer.h"
|
||||||
|
@ -150,11 +152,12 @@ absl::Status InferenceCalculatorMetalImpl::Process(CalculatorContext* cc) {
|
||||||
command_buffer.label = @"InferenceCalculator";
|
command_buffer.label = @"InferenceCalculator";
|
||||||
// Explicit copy input with conversion float 32 bits to 16 bits.
|
// Explicit copy input with conversion float 32 bits to 16 bits.
|
||||||
for (int i = 0; i < input_tensors.size(); ++i) {
|
for (int i = 0; i < input_tensors.size(); ++i) {
|
||||||
auto input_view = input_tensors[i].GetMtlBufferReadView(command_buffer);
|
auto input_view =
|
||||||
|
MtlBufferView::GetReadView(input_tensors[i], command_buffer);
|
||||||
// Reshape tensor.
|
// Reshape tensor.
|
||||||
tflite::gpu::BHWC shape = BhwcFromTensorShape(input_tensors[i].shape());
|
tflite::gpu::BHWC shape = BhwcFromTensorShape(input_tensors[i].shape());
|
||||||
auto gpu_buffer_view =
|
auto gpu_buffer_view =
|
||||||
gpu_buffers_in_[i]->GetMtlBufferWriteView(command_buffer);
|
MtlBufferView::GetWriteView(*gpu_buffers_in_[i], command_buffer);
|
||||||
id<MTLComputeCommandEncoder> input_encoder =
|
id<MTLComputeCommandEncoder> input_encoder =
|
||||||
[command_buffer computeCommandEncoder];
|
[command_buffer computeCommandEncoder];
|
||||||
[converter_to_BPHWC4_ convertWithEncoder:input_encoder
|
[converter_to_BPHWC4_ convertWithEncoder:input_encoder
|
||||||
|
@ -174,9 +177,10 @@ absl::Status InferenceCalculatorMetalImpl::Process(CalculatorContext* cc) {
|
||||||
output_shapes_[i]);
|
output_shapes_[i]);
|
||||||
// Reshape tensor.
|
// Reshape tensor.
|
||||||
tflite::gpu::BHWC shape = BhwcFromTensorShape(output_shapes_[i]);
|
tflite::gpu::BHWC shape = BhwcFromTensorShape(output_shapes_[i]);
|
||||||
auto read_view = gpu_buffers_out_[i]->GetMtlBufferReadView(command_buffer);
|
auto read_view =
|
||||||
|
MtlBufferView::GetReadView(*gpu_buffers_out_[i], command_buffer);
|
||||||
auto write_view =
|
auto write_view =
|
||||||
output_tensors->at(i).GetMtlBufferWriteView(command_buffer);
|
MtlBufferView::GetWriteView(output_tensors->at(i), command_buffer);
|
||||||
id<MTLComputeCommandEncoder> output_encoder =
|
id<MTLComputeCommandEncoder> output_encoder =
|
||||||
[command_buffer computeCommandEncoder];
|
[command_buffer computeCommandEncoder];
|
||||||
[converter_from_BPHWC4_ convertWithEncoder:output_encoder
|
[converter_from_BPHWC4_ convertWithEncoder:output_encoder
|
||||||
|
@ -258,7 +262,7 @@ absl::Status InferenceCalculatorMetalImpl::CreateConverters(
|
||||||
: Tensor::ElementType::kFloat32,
|
: Tensor::ElementType::kFloat32,
|
||||||
Tensor::Shape{dims}));
|
Tensor::Shape{dims}));
|
||||||
auto buffer_view =
|
auto buffer_view =
|
||||||
gpu_buffers_in_[i]->GetMtlBufferWriteView(gpu_helper_.mtlDevice);
|
MtlBufferView::GetWriteView(*gpu_buffers_in_[i], gpu_helper_.mtlDevice);
|
||||||
RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor(
|
RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor(
|
||||||
delegate_.get(), input_indices[i], buffer_view.buffer()),
|
delegate_.get(), input_indices[i], buffer_view.buffer()),
|
||||||
true);
|
true);
|
||||||
|
@ -286,8 +290,8 @@ absl::Status InferenceCalculatorMetalImpl::CreateConverters(
|
||||||
Tensor::Shape{dims}));
|
Tensor::Shape{dims}));
|
||||||
RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor(
|
RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor(
|
||||||
delegate_.get(), output_indices[i],
|
delegate_.get(), output_indices[i],
|
||||||
gpu_buffers_out_[i]
|
MtlBufferView::GetWriteView(*gpu_buffers_out_[i],
|
||||||
->GetMtlBufferWriteView(gpu_helper_.mtlDevice)
|
gpu_helper_.mtlDevice)
|
||||||
.buffer()),
|
.buffer()),
|
||||||
true);
|
true);
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,6 +31,7 @@
|
||||||
#import <Metal/Metal.h>
|
#import <Metal/Metal.h>
|
||||||
#import <MetalKit/MetalKit.h>
|
#import <MetalKit/MetalKit.h>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/formats/tensor_mtl_buffer_view.h"
|
||||||
#import "mediapipe/gpu/MPPMetalHelper.h"
|
#import "mediapipe/gpu/MPPMetalHelper.h"
|
||||||
#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
||||||
#include "mediapipe/gpu/gl_calculator_helper.h"
|
#include "mediapipe/gpu/gl_calculator_helper.h"
|
||||||
|
@ -304,7 +305,7 @@ absl::Status TensorConverterCalculator::ProcessGPU(CalculatorContext* cc) {
|
||||||
id<MTLTexture> src_texture = [gpu_helper_ metalTextureWithGpuBuffer:input];
|
id<MTLTexture> src_texture = [gpu_helper_ metalTextureWithGpuBuffer:input];
|
||||||
[compute_encoder setTexture:src_texture atIndex:0];
|
[compute_encoder setTexture:src_texture atIndex:0];
|
||||||
auto output_view =
|
auto output_view =
|
||||||
output_tensors->at(0).GetMtlBufferWriteView(command_buffer);
|
MtlBufferView::GetWriteView(output_tensors->at(0), command_buffer);
|
||||||
[compute_encoder setBuffer:output_view.buffer() offset:0 atIndex:1];
|
[compute_encoder setBuffer:output_view.buffer() offset:0 atIndex:1];
|
||||||
MTLSize threads_per_group = MTLSizeMake(kWorkgroupSize, kWorkgroupSize, 1);
|
MTLSize threads_per_group = MTLSizeMake(kWorkgroupSize, kWorkgroupSize, 1);
|
||||||
MTLSize threadgroups =
|
MTLSize threadgroups =
|
||||||
|
|
|
@ -41,6 +41,7 @@
|
||||||
#import <Metal/Metal.h>
|
#import <Metal/Metal.h>
|
||||||
#import <MetalKit/MetalKit.h>
|
#import <MetalKit/MetalKit.h>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/formats/tensor_mtl_buffer_view.h"
|
||||||
#import "mediapipe/gpu/MPPMetalHelper.h"
|
#import "mediapipe/gpu/MPPMetalHelper.h"
|
||||||
#include "mediapipe/gpu/MPPMetalUtil.h"
|
#include "mediapipe/gpu/MPPMetalUtil.h"
|
||||||
#endif // MEDIAPIPE_METAL_ENABLED
|
#endif // MEDIAPIPE_METAL_ENABLED
|
||||||
|
@ -536,10 +537,11 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU(
|
||||||
if (input_tensors.size() == kNumInputTensorsWithAnchors) {
|
if (input_tensors.size() == kNumInputTensorsWithAnchors) {
|
||||||
RET_CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors);
|
RET_CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors);
|
||||||
auto command_buffer = [gpu_helper_ commandBuffer];
|
auto command_buffer = [gpu_helper_ commandBuffer];
|
||||||
auto src_buffer = input_tensors[tensor_mapping_.anchors_tensor_index()]
|
auto src_buffer = MtlBufferView::GetReadView(
|
||||||
.GetMtlBufferReadView(command_buffer);
|
input_tensors[tensor_mapping_.anchors_tensor_index()],
|
||||||
|
command_buffer);
|
||||||
auto dest_buffer =
|
auto dest_buffer =
|
||||||
raw_anchors_buffer_->GetMtlBufferWriteView(command_buffer);
|
MtlBufferView::GetWriteView(*raw_anchors_buffer_, command_buffer);
|
||||||
id<MTLBlitCommandEncoder> blit_command =
|
id<MTLBlitCommandEncoder> blit_command =
|
||||||
[command_buffer blitCommandEncoder];
|
[command_buffer blitCommandEncoder];
|
||||||
[blit_command copyFromBuffer:src_buffer.buffer()
|
[blit_command copyFromBuffer:src_buffer.buffer()
|
||||||
|
@ -571,15 +573,16 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU(
|
||||||
[command_encoder setComputePipelineState:decode_program_];
|
[command_encoder setComputePipelineState:decode_program_];
|
||||||
{
|
{
|
||||||
auto scored_boxes_view =
|
auto scored_boxes_view =
|
||||||
scored_boxes_buffer_->GetMtlBufferWriteView(command_buffer);
|
MtlBufferView::GetWriteView(*scored_boxes_buffer_, command_buffer);
|
||||||
auto decoded_boxes_view =
|
auto decoded_boxes_view =
|
||||||
decoded_boxes_buffer_->GetMtlBufferWriteView(command_buffer);
|
MtlBufferView::GetWriteView(*decoded_boxes_buffer_, command_buffer);
|
||||||
[command_encoder setBuffer:decoded_boxes_view.buffer() offset:0 atIndex:0];
|
[command_encoder setBuffer:decoded_boxes_view.buffer() offset:0 atIndex:0];
|
||||||
auto input0_view = input_tensors[tensor_mapping_.detections_tensor_index()]
|
auto input0_view = MtlBufferView::GetReadView(
|
||||||
.GetMtlBufferReadView(command_buffer);
|
input_tensors[tensor_mapping_.detections_tensor_index()],
|
||||||
|
command_buffer);
|
||||||
[command_encoder setBuffer:input0_view.buffer() offset:0 atIndex:1];
|
[command_encoder setBuffer:input0_view.buffer() offset:0 atIndex:1];
|
||||||
auto raw_anchors_view =
|
auto raw_anchors_view =
|
||||||
raw_anchors_buffer_->GetMtlBufferReadView(command_buffer);
|
MtlBufferView::GetReadView(*raw_anchors_buffer_, command_buffer);
|
||||||
[command_encoder setBuffer:raw_anchors_view.buffer() offset:0 atIndex:2];
|
[command_encoder setBuffer:raw_anchors_view.buffer() offset:0 atIndex:2];
|
||||||
MTLSize decode_threads_per_group = MTLSizeMake(1, 1, 1);
|
MTLSize decode_threads_per_group = MTLSizeMake(1, 1, 1);
|
||||||
MTLSize decode_threadgroups = MTLSizeMake(num_boxes_, 1, 1);
|
MTLSize decode_threadgroups = MTLSizeMake(num_boxes_, 1, 1);
|
||||||
|
@ -588,8 +591,8 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU(
|
||||||
|
|
||||||
[command_encoder setComputePipelineState:score_program_];
|
[command_encoder setComputePipelineState:score_program_];
|
||||||
[command_encoder setBuffer:scored_boxes_view.buffer() offset:0 atIndex:0];
|
[command_encoder setBuffer:scored_boxes_view.buffer() offset:0 atIndex:0];
|
||||||
auto input1_view = input_tensors[tensor_mapping_.scores_tensor_index()]
|
auto input1_view = MtlBufferView::GetReadView(
|
||||||
.GetMtlBufferReadView(command_buffer);
|
input_tensors[tensor_mapping_.scores_tensor_index()], command_buffer);
|
||||||
[command_encoder setBuffer:input1_view.buffer() offset:0 atIndex:1];
|
[command_encoder setBuffer:input1_view.buffer() offset:0 atIndex:1];
|
||||||
MTLSize score_threads_per_group = MTLSizeMake(1, num_classes_, 1);
|
MTLSize score_threads_per_group = MTLSizeMake(1, num_classes_, 1);
|
||||||
MTLSize score_threadgroups = MTLSizeMake(num_boxes_, 1, 1);
|
MTLSize score_threadgroups = MTLSizeMake(num_boxes_, 1, 1);
|
||||||
|
|
|
@ -53,6 +53,7 @@
|
||||||
#import <Metal/Metal.h>
|
#import <Metal/Metal.h>
|
||||||
#import <MetalKit/MetalKit.h>
|
#import <MetalKit/MetalKit.h>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/formats/tensor_mtl_buffer_view.h"
|
||||||
#import "mediapipe/gpu/MPPMetalHelper.h"
|
#import "mediapipe/gpu/MPPMetalHelper.h"
|
||||||
#include "mediapipe/gpu/MPPMetalUtil.h"
|
#include "mediapipe/gpu/MPPMetalUtil.h"
|
||||||
#endif // MEDIAPIPE_METAL_ENABLED
|
#endif // MEDIAPIPE_METAL_ENABLED
|
||||||
|
@ -485,7 +486,8 @@ absl::Status TensorsToSegmentationCalculator::ProcessGpu(
|
||||||
[command_buffer computeCommandEncoder];
|
[command_buffer computeCommandEncoder];
|
||||||
[command_encoder setComputePipelineState:mask_program_];
|
[command_encoder setComputePipelineState:mask_program_];
|
||||||
|
|
||||||
auto read_view = input_tensors[0].GetMtlBufferReadView(command_buffer);
|
auto read_view =
|
||||||
|
MtlBufferView::GetReadView(input_tensors[0], command_buffer);
|
||||||
[command_encoder setBuffer:read_view.buffer() offset:0 atIndex:0];
|
[command_encoder setBuffer:read_view.buffer() offset:0 atIndex:0];
|
||||||
|
|
||||||
mediapipe::GpuBuffer small_mask_buffer = [metal_helper_
|
mediapipe::GpuBuffer small_mask_buffer = [metal_helper_
|
||||||
|
|
|
@ -431,7 +431,10 @@ cc_library(
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"tensor.h",
|
"tensor.h",
|
||||||
"//mediapipe/framework/formats/tensor:internal.h",
|
"//mediapipe/framework/formats/tensor:internal.h",
|
||||||
],
|
] + select({
|
||||||
|
"//mediapipe:ios": ["tensor_mtl_buffer_view.h"],
|
||||||
|
"//conditions:default": [],
|
||||||
|
}),
|
||||||
copts = select({
|
copts = select({
|
||||||
"//mediapipe:apple": [
|
"//mediapipe:apple": [
|
||||||
"-x objective-c++",
|
"-x objective-c++",
|
||||||
|
|
|
@ -25,8 +25,11 @@
|
||||||
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
||||||
|
|
||||||
#if MEDIAPIPE_METAL_ENABLED
|
#if MEDIAPIPE_METAL_ENABLED
|
||||||
|
#import <Metal/Metal.h>
|
||||||
#include <mach/mach_init.h>
|
#include <mach/mach_init.h>
|
||||||
#include <mach/vm_map.h>
|
#include <mach/vm_map.h>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/formats/tensor_mtl_buffer_view.h"
|
||||||
#else
|
#else
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#endif // MEDIAPIPE_METAL_ENABLED
|
#endif // MEDIAPIPE_METAL_ENABLED
|
||||||
|
@ -61,6 +64,12 @@ int BhwcDepthFromShape(const Tensor::Shape& shape) {
|
||||||
// 3) pad/"unpad" the bitmap after transfer CPU <-> GPU
|
// 3) pad/"unpad" the bitmap after transfer CPU <-> GPU
|
||||||
|
|
||||||
#if MEDIAPIPE_METAL_ENABLED
|
#if MEDIAPIPE_METAL_ENABLED
|
||||||
|
// No ODR violation here because this file compiled just once per project.
|
||||||
|
struct MtlResources {
|
||||||
|
id<MTLCommandBuffer> command_buffer = nil;
|
||||||
|
id<MTLDevice> device = nil;
|
||||||
|
id<MTLBuffer> metal_buffer = nil;
|
||||||
|
};
|
||||||
namespace {
|
namespace {
|
||||||
// MTLBuffer can use existing properly aligned and allocated CPU memory.
|
// MTLBuffer can use existing properly aligned and allocated CPU memory.
|
||||||
size_t AlignToPageSize(size_t size) {
|
size_t AlignToPageSize(size_t size) {
|
||||||
|
@ -83,52 +92,56 @@ void DeallocateVirtualMemory(void* pointer, size_t size) {
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Tensor::MtlBufferView Tensor::GetMtlBufferReadView(
|
void MtlBufferView::AllocateMtlBuffer(const Tensor& tensor,
|
||||||
id<MTLCommandBuffer> command_buffer) const {
|
id<MTLDevice> device) {
|
||||||
LOG_IF(FATAL, valid_ == kValidNone)
|
tensor.mtl_resources_->device = device;
|
||||||
|
if (!tensor.cpu_buffer_) {
|
||||||
|
// It also means that the metal buffer is not allocated yet.
|
||||||
|
tensor.cpu_buffer_ = AllocateVirtualMemory(tensor.bytes());
|
||||||
|
}
|
||||||
|
if (!tensor.mtl_resources_->metal_buffer) {
|
||||||
|
tensor.mtl_resources_->metal_buffer = [tensor.mtl_resources_->device
|
||||||
|
newBufferWithBytesNoCopy:tensor.cpu_buffer_
|
||||||
|
length:AlignToPageSize(tensor.bytes())
|
||||||
|
options:MTLResourceStorageModeShared |
|
||||||
|
MTLResourceCPUCacheModeDefaultCache
|
||||||
|
deallocator:^(void* pointer, NSUInteger length) {
|
||||||
|
DeallocateVirtualMemory(pointer, length);
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
MtlBufferView MtlBufferView::GetReadView(const Tensor& tensor,
|
||||||
|
id<MTLCommandBuffer> command_buffer) {
|
||||||
|
LOG_IF(FATAL, tensor.valid_ == Tensor::kValidNone)
|
||||||
<< "Tensor must be written prior to read from.";
|
<< "Tensor must be written prior to read from.";
|
||||||
LOG_IF(FATAL, !(valid_ & (kValidCpu | kValidMetalBuffer)))
|
LOG_IF(FATAL,
|
||||||
|
!(tensor.valid_ & (Tensor::kValidCpu | Tensor::kValidMetalBuffer)))
|
||||||
<< "Tensor conversion between different GPU resources is not supported "
|
<< "Tensor conversion between different GPU resources is not supported "
|
||||||
"yet.";
|
"yet.";
|
||||||
auto lock(absl::make_unique<absl::MutexLock>(&view_mutex_));
|
auto lock(absl::make_unique<absl::MutexLock>(&tensor.view_mutex_));
|
||||||
valid_ |= kValidMetalBuffer;
|
tensor.valid_ |= Tensor::kValidMetalBuffer;
|
||||||
AllocateMtlBuffer([command_buffer device]);
|
AllocateMtlBuffer(tensor, [command_buffer device]);
|
||||||
return {metal_buffer_, std::move(lock)};
|
return {tensor.mtl_resources_->metal_buffer, std::move(lock)};
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor::MtlBufferView Tensor::GetMtlBufferWriteView(
|
MtlBufferView MtlBufferView::GetWriteView(const Tensor& tensor,
|
||||||
id<MTLCommandBuffer> command_buffer) const {
|
id<MTLCommandBuffer> command_buffer) {
|
||||||
// Don't overwrite command buffer at which the metal buffer has been written
|
// Don't overwrite command buffer at which the metal buffer has been written
|
||||||
// so we can wait until completed.
|
// so we can wait until completed.
|
||||||
command_buffer_ = command_buffer;
|
tensor.mtl_resources_->command_buffer = command_buffer;
|
||||||
return GetMtlBufferWriteView([command_buffer device]);
|
return GetWriteView(tensor, [command_buffer device]);
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor::MtlBufferView Tensor::GetMtlBufferWriteView(
|
MtlBufferView MtlBufferView::GetWriteView(const Tensor& tensor,
|
||||||
id<MTLDevice> device) const {
|
id<MTLDevice> device) {
|
||||||
auto lock(absl::make_unique<absl::MutexLock>(&view_mutex_));
|
auto lock(absl::make_unique<absl::MutexLock>(&tensor.view_mutex_));
|
||||||
valid_ = kValidMetalBuffer;
|
tensor.valid_ = Tensor::kValidMetalBuffer;
|
||||||
AllocateMtlBuffer(device);
|
AllocateMtlBuffer(tensor, device);
|
||||||
return {metal_buffer_, std::move(lock)};
|
return {tensor.mtl_resources_->metal_buffer, std::move(lock)};
|
||||||
}
|
|
||||||
|
|
||||||
void Tensor::AllocateMtlBuffer(id<MTLDevice> device) const {
|
|
||||||
device_ = device;
|
|
||||||
if (!cpu_buffer_) {
|
|
||||||
// It also means that the metal buffer is not allocated yet.
|
|
||||||
cpu_buffer_ = AllocateVirtualMemory(bytes());
|
|
||||||
}
|
|
||||||
if (!metal_buffer_) {
|
|
||||||
metal_buffer_ =
|
|
||||||
[device_ newBufferWithBytesNoCopy:cpu_buffer_
|
|
||||||
length:AlignToPageSize(bytes())
|
|
||||||
options:MTLResourceStorageModeShared |
|
|
||||||
MTLResourceCPUCacheModeDefaultCache
|
|
||||||
deallocator:^(void* pointer, NSUInteger length) {
|
|
||||||
DeallocateVirtualMemory(pointer, length);
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
struct MtlResources {};
|
||||||
#endif // MEDIAPIPE_METAL_ENABLED
|
#endif // MEDIAPIPE_METAL_ENABLED
|
||||||
|
|
||||||
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
||||||
|
@ -379,6 +392,9 @@ Tensor& Tensor::operator=(Tensor&& src) {
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Tensor::Tensor(Tensor&& src) { Move(&src); }
|
||||||
|
Tensor::~Tensor() { Invalidate(); }
|
||||||
|
|
||||||
void Tensor::Move(Tensor* src) {
|
void Tensor::Move(Tensor* src) {
|
||||||
valid_ = src->valid_;
|
valid_ = src->valid_;
|
||||||
src->valid_ = kValidNone;
|
src->valid_ = kValidNone;
|
||||||
|
@ -388,15 +404,7 @@ void Tensor::Move(Tensor* src) {
|
||||||
cpu_buffer_ = src->cpu_buffer_;
|
cpu_buffer_ = src->cpu_buffer_;
|
||||||
src->cpu_buffer_ = nullptr;
|
src->cpu_buffer_ = nullptr;
|
||||||
ahwb_tracking_key_ = src->ahwb_tracking_key_;
|
ahwb_tracking_key_ = src->ahwb_tracking_key_;
|
||||||
#if MEDIAPIPE_METAL_ENABLED
|
mtl_resources_ = std::move(src->mtl_resources_);
|
||||||
device_ = src->device_;
|
|
||||||
src->device_ = nil;
|
|
||||||
command_buffer_ = src->command_buffer_;
|
|
||||||
src->command_buffer_ = nil;
|
|
||||||
metal_buffer_ = src->metal_buffer_;
|
|
||||||
src->metal_buffer_ = nil;
|
|
||||||
#endif // MEDIAPIPE_METAL_ENABLED
|
|
||||||
|
|
||||||
MoveAhwbStuff(src);
|
MoveAhwbStuff(src);
|
||||||
|
|
||||||
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
||||||
|
@ -415,12 +423,15 @@ void Tensor::Move(Tensor* src) {
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor::Tensor(ElementType element_type, const Shape& shape)
|
Tensor::Tensor(ElementType element_type, const Shape& shape)
|
||||||
: element_type_(element_type), shape_(shape) {}
|
: element_type_(element_type),
|
||||||
|
shape_(shape),
|
||||||
|
mtl_resources_(std::make_unique<MtlResources>()) {}
|
||||||
Tensor::Tensor(ElementType element_type, const Shape& shape,
|
Tensor::Tensor(ElementType element_type, const Shape& shape,
|
||||||
const QuantizationParameters& quantization_parameters)
|
const QuantizationParameters& quantization_parameters)
|
||||||
: element_type_(element_type),
|
: element_type_(element_type),
|
||||||
shape_(shape),
|
shape_(shape),
|
||||||
quantization_parameters_(quantization_parameters) {}
|
quantization_parameters_(quantization_parameters),
|
||||||
|
mtl_resources_(std::make_unique<MtlResources>()) {}
|
||||||
|
|
||||||
#if MEDIAPIPE_METAL_ENABLED
|
#if MEDIAPIPE_METAL_ENABLED
|
||||||
void Tensor::Invalidate() {
|
void Tensor::Invalidate() {
|
||||||
|
@ -432,13 +443,16 @@ void Tensor::Invalidate() {
|
||||||
absl::MutexLock lock(&view_mutex_);
|
absl::MutexLock lock(&view_mutex_);
|
||||||
// If memory is allocated and not owned by the metal buffer.
|
// If memory is allocated and not owned by the metal buffer.
|
||||||
// TODO: Re-design cpu buffer memory management.
|
// TODO: Re-design cpu buffer memory management.
|
||||||
if (cpu_buffer_ && !metal_buffer_) {
|
if (cpu_buffer_ && !mtl_resources_->metal_buffer) {
|
||||||
DeallocateVirtualMemory(cpu_buffer_, AlignToPageSize(bytes()));
|
DeallocateVirtualMemory(cpu_buffer_, AlignToPageSize(bytes()));
|
||||||
}
|
}
|
||||||
metal_buffer_ = nil;
|
|
||||||
command_buffer_ = nil;
|
|
||||||
device_ = nil;
|
|
||||||
cpu_buffer_ = nullptr;
|
cpu_buffer_ = nullptr;
|
||||||
|
// This becomes NULL if the tensor is moved.
|
||||||
|
if (mtl_resources_) {
|
||||||
|
mtl_resources_->metal_buffer = nil;
|
||||||
|
mtl_resources_->command_buffer = nil;
|
||||||
|
mtl_resources_->device = nil;
|
||||||
|
}
|
||||||
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
||||||
// Don't need to wait for the resource to be deleted bacause if will be
|
// Don't need to wait for the resource to be deleted bacause if will be
|
||||||
// released on last reference deletion inside the OpenGL driver.
|
// released on last reference deletion inside the OpenGL driver.
|
||||||
|
@ -532,10 +546,11 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const {
|
||||||
// GPU-to-CPU synchronization and read-back.
|
// GPU-to-CPU synchronization and read-back.
|
||||||
#if MEDIAPIPE_METAL_ENABLED
|
#if MEDIAPIPE_METAL_ENABLED
|
||||||
if (valid_ & kValidMetalBuffer) {
|
if (valid_ & kValidMetalBuffer) {
|
||||||
LOG_IF(FATAL, !command_buffer_) << "Metal -> CPU synchronization "
|
LOG_IF(FATAL, !mtl_resources_->command_buffer)
|
||||||
"requires MTLCommandBuffer to be set.";
|
<< "Metal -> CPU synchronization "
|
||||||
if (command_buffer_) {
|
"requires MTLCommandBuffer to be set.";
|
||||||
[command_buffer_ waitUntilCompleted];
|
if (mtl_resources_->command_buffer) {
|
||||||
|
[mtl_resources_->command_buffer waitUntilCompleted];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif // MEDIAPIPE_METAL_ENABLED
|
#endif // MEDIAPIPE_METAL_ENABLED
|
||||||
|
|
|
@ -29,9 +29,6 @@
|
||||||
#include "mediapipe/framework/formats/tensor/internal.h"
|
#include "mediapipe/framework/formats/tensor/internal.h"
|
||||||
#include "mediapipe/framework/port.h"
|
#include "mediapipe/framework/port.h"
|
||||||
|
|
||||||
#if MEDIAPIPE_METAL_ENABLED
|
|
||||||
#import <Metal/Metal.h>
|
|
||||||
#endif // MEDIAPIPE_METAL_ENABLED
|
|
||||||
#ifndef MEDIAPIPE_NO_JNI
|
#ifndef MEDIAPIPE_NO_JNI
|
||||||
#if __ANDROID_API__ >= 26 || defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)
|
#if __ANDROID_API__ >= 26 || defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)
|
||||||
#define MEDIAPIPE_TENSOR_USE_AHWB 1
|
#define MEDIAPIPE_TENSOR_USE_AHWB 1
|
||||||
|
@ -66,7 +63,6 @@
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
// Tensor is a container of multi-dimensional data that supports sharing the
|
// Tensor is a container of multi-dimensional data that supports sharing the
|
||||||
// content across different backends and APIs, currently: CPU / Metal / OpenGL.
|
// content across different backends and APIs, currently: CPU / Metal / OpenGL.
|
||||||
// Texture2DView is limited to 4 dimensions.
|
// Texture2DView is limited to 4 dimensions.
|
||||||
|
@ -91,6 +87,7 @@ namespace mediapipe {
|
||||||
// float* pointer = view.buffer<float>();
|
// float* pointer = view.buffer<float>();
|
||||||
// ...reading the cpu memory...
|
// ...reading the cpu memory...
|
||||||
|
|
||||||
|
struct MtlResources;
|
||||||
class Tensor {
|
class Tensor {
|
||||||
class View {
|
class View {
|
||||||
public:
|
public:
|
||||||
|
@ -144,9 +141,9 @@ class Tensor {
|
||||||
Tensor(const Tensor&) = delete;
|
Tensor(const Tensor&) = delete;
|
||||||
Tensor& operator=(const Tensor&) = delete;
|
Tensor& operator=(const Tensor&) = delete;
|
||||||
// Move-only.
|
// Move-only.
|
||||||
Tensor(Tensor&& src) { Move(&src); }
|
Tensor(Tensor&& src);
|
||||||
Tensor& operator=(Tensor&&);
|
Tensor& operator=(Tensor&&);
|
||||||
~Tensor() { Invalidate(); }
|
~Tensor();
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
class CpuView : public View {
|
class CpuView : public View {
|
||||||
|
@ -182,33 +179,6 @@ class Tensor {
|
||||||
uint64_t source_location_hash =
|
uint64_t source_location_hash =
|
||||||
tensor_internal::FnvHash64(builtin_FILE(), builtin_LINE())) const;
|
tensor_internal::FnvHash64(builtin_FILE(), builtin_LINE())) const;
|
||||||
|
|
||||||
#if MEDIAPIPE_METAL_ENABLED
|
|
||||||
// TODO: id<MTLBuffer> vs. MtlBufferView.
|
|
||||||
class MtlBufferView : public View {
|
|
||||||
public:
|
|
||||||
id<MTLBuffer> buffer() const { return buffer_; }
|
|
||||||
MtlBufferView(MtlBufferView&& src)
|
|
||||||
: View(std::move(src)), buffer_(src.buffer_) {
|
|
||||||
src.buffer_ = nil;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
friend class Tensor;
|
|
||||||
MtlBufferView(id<MTLBuffer> buffer, std::unique_ptr<absl::MutexLock>&& lock)
|
|
||||||
: View(std::move(lock)), buffer_(buffer) {}
|
|
||||||
id<MTLBuffer> buffer_;
|
|
||||||
};
|
|
||||||
// The command buffer status is checked for completeness if GPU-to-CPU
|
|
||||||
// synchronization is required.
|
|
||||||
// TODO: Design const and non-const view acquiring.
|
|
||||||
MtlBufferView GetMtlBufferReadView(id<MTLCommandBuffer> command_buffer) const;
|
|
||||||
MtlBufferView GetMtlBufferWriteView(
|
|
||||||
id<MTLCommandBuffer> command_buffer) const;
|
|
||||||
// Allocate new buffer.
|
|
||||||
// TODO: GPU-to-CPU design considerations.
|
|
||||||
MtlBufferView GetMtlBufferWriteView(id<MTLDevice> device) const;
|
|
||||||
#endif // MEDIAPIPE_METAL_ENABLED
|
|
||||||
|
|
||||||
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
|
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
|
||||||
using FinishingFunc = std::function<bool(bool)>;
|
using FinishingFunc = std::function<bool(bool)>;
|
||||||
class AHardwareBufferView : public View {
|
class AHardwareBufferView : public View {
|
||||||
|
@ -372,6 +342,7 @@ class Tensor {
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
friend class MtlBufferView;
|
||||||
void Move(Tensor*);
|
void Move(Tensor*);
|
||||||
void Invalidate();
|
void Invalidate();
|
||||||
|
|
||||||
|
@ -396,12 +367,9 @@ class Tensor {
|
||||||
|
|
||||||
mutable void* cpu_buffer_ = nullptr;
|
mutable void* cpu_buffer_ = nullptr;
|
||||||
void AllocateCpuBuffer() const;
|
void AllocateCpuBuffer() const;
|
||||||
#if MEDIAPIPE_METAL_ENABLED
|
// Forward declaration of the MtlResources provides compile-time verification
|
||||||
mutable id<MTLCommandBuffer> command_buffer_ = nil;
|
// of ODR if this header includes any actual code that uses MtlResources.
|
||||||
mutable id<MTLDevice> device_ = nil;
|
mutable std::unique_ptr<MtlResources> mtl_resources_;
|
||||||
mutable id<MTLBuffer> metal_buffer_ = nil;
|
|
||||||
void AllocateMtlBuffer(id<MTLDevice> device) const;
|
|
||||||
#endif // MEDIAPIPE_METAL_ENABLED
|
|
||||||
|
|
||||||
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
|
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
|
||||||
mutable AHardwareBuffer* ahwb_ = nullptr;
|
mutable AHardwareBuffer* ahwb_ = nullptr;
|
||||||
|
|
61
mediapipe/framework/formats/tensor_mtl_buffer_view.h
Normal file
61
mediapipe/framework/formats/tensor_mtl_buffer_view.h
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
// Copyright 2020 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_MTL_BUFFER_VIEW_H_
|
||||||
|
#define MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_MTL_BUFFER_VIEW_H_
|
||||||
|
|
||||||
|
#import <Metal/Metal.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <functional>
|
||||||
|
#include <initializer_list>
|
||||||
|
#include <numeric>
|
||||||
|
#include <tuple>
|
||||||
|
#include <type_traits>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
|
#include "absl/synchronization/mutex.h"
|
||||||
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
|
#include "mediapipe/framework/port.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
class MtlBufferView : public Tensor::View {
|
||||||
|
public:
|
||||||
|
// The command buffer status is checked for completeness if GPU-to-CPU
|
||||||
|
// synchronization is required.
|
||||||
|
static MtlBufferView GetReadView(const Tensor& tensor,
|
||||||
|
id<MTLCommandBuffer> command_buffer);
|
||||||
|
static MtlBufferView GetWriteView(const Tensor& tensor,
|
||||||
|
id<MTLCommandBuffer> command_buffer);
|
||||||
|
static MtlBufferView GetWriteView(const Tensor& tensor, id<MTLDevice> device);
|
||||||
|
|
||||||
|
id<MTLBuffer> buffer() const { return buffer_; }
|
||||||
|
MtlBufferView(MtlBufferView&& src)
|
||||||
|
: Tensor::View(std::move(src)), buffer_(src.buffer_) {
|
||||||
|
src.buffer_ = nil;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
friend class Tensor;
|
||||||
|
static void AllocateMtlBuffer(const Tensor& tensor, id<MTLDevice> device);
|
||||||
|
MtlBufferView(id<MTLBuffer> buffer, std::unique_ptr<absl::MutexLock>&& lock)
|
||||||
|
: Tensor::View(std::move(lock)), buffer_(buffer) {}
|
||||||
|
id<MTLBuffer> buffer_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_MTL_BUFFER_VIEW_H_
|
|
@ -79,14 +79,6 @@ mediapipe_proto_library(
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "score_calibration_calculator",
|
name = "score_calibration_calculator",
|
||||||
srcs = ["score_calibration_calculator.cc"],
|
srcs = ["score_calibration_calculator.cc"],
|
||||||
copts = select({
|
|
||||||
# TODO: fix tensor.h not to require this, if possible
|
|
||||||
"//mediapipe:apple": [
|
|
||||||
"-x objective-c++",
|
|
||||||
"-fobjc-arc", # enable reference-counting
|
|
||||||
],
|
|
||||||
"//conditions:default": [],
|
|
||||||
}),
|
|
||||||
deps = [
|
deps = [
|
||||||
":score_calibration_calculator_cc_proto",
|
":score_calibration_calculator_cc_proto",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
|
|
@ -28,14 +28,6 @@ cc_library(
|
||||||
name = "classification_postprocessing_graph",
|
name = "classification_postprocessing_graph",
|
||||||
srcs = ["classification_postprocessing_graph.cc"],
|
srcs = ["classification_postprocessing_graph.cc"],
|
||||||
hdrs = ["classification_postprocessing_graph.h"],
|
hdrs = ["classification_postprocessing_graph.h"],
|
||||||
copts = select({
|
|
||||||
# TODO: fix tensor.h not to require this, if possible
|
|
||||||
"//mediapipe:apple": [
|
|
||||||
"-x objective-c++",
|
|
||||||
"-fobjc-arc", # enable reference-counting
|
|
||||||
],
|
|
||||||
"//conditions:default": [],
|
|
||||||
}),
|
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/calculators/core:split_vector_calculator",
|
"//mediapipe/calculators/core:split_vector_calculator",
|
||||||
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
|
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
|
||||||
|
@ -148,14 +140,6 @@ cc_library(
|
||||||
name = "text_preprocessing_graph",
|
name = "text_preprocessing_graph",
|
||||||
srcs = ["text_preprocessing_graph.cc"],
|
srcs = ["text_preprocessing_graph.cc"],
|
||||||
hdrs = ["text_preprocessing_graph.h"],
|
hdrs = ["text_preprocessing_graph.h"],
|
||||||
copts = select({
|
|
||||||
# TODO: fix tensor.h not to require this, if possible
|
|
||||||
"//mediapipe:apple": [
|
|
||||||
"-x objective-c++",
|
|
||||||
"-fobjc-arc", # enable reference-counting
|
|
||||||
],
|
|
||||||
"//conditions:default": [],
|
|
||||||
}),
|
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/calculators/tensor:bert_preprocessor_calculator",
|
"//mediapipe/calculators/tensor:bert_preprocessor_calculator",
|
||||||
"//mediapipe/calculators/tensor:bert_preprocessor_calculator_cc_proto",
|
"//mediapipe/calculators/tensor:bert_preprocessor_calculator_cc_proto",
|
||||||
|
|
Loading…
Reference in New Issue
Block a user