Fix bug: output tensor data buffers were being overwritten in tflite_converter_calculator, tflite_inference_calculator, and tflite_tensors_to_detections_calculator. The shared buffers caused wrong object-detection output under multi-threading and made the output video unstable (different output for the same input video).

thuan 2020-09-03 18:32:30 +08:00
parent 1db91b550a
commit 328905ec6b
7 changed files with 157 additions and 37 deletions
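
The root cause, as a minimal self-contained sketch (plain C++, no TFLite dependency; all names here are illustrative): the interpreter reuses the same output buffer on every Invoke(), so a packet that only stores the TfLiteTensor struct still points at that buffer, and a later inference overwrites the data before a downstream calculator has read it. Deep-copying the bytes, as the new TfLiteTensorContainer does, decouples the emitted packet from the interpreter.

#include <cstddef>
#include <cstring>
#include <iostream>
#include <memory>
#include <vector>

// Illustrative stand-in for a TfLiteTensor: the data pointer is owned by the
// "interpreter" and reused across invocations.
struct FakeTensor {
  float* data;
  std::size_t bytes;
};

// What TfLiteTensorContainer does: own a private copy of the bytes.
struct CopiedTensor {
  std::unique_ptr<char[]> storage;
  FakeTensor tensor;
  explicit CopiedTensor(const FakeTensor& src) : storage(new char[src.bytes]) {
    std::memcpy(storage.get(), src.data, src.bytes);
    tensor = src;
    tensor.data = reinterpret_cast<float*>(storage.get());
  }
};

int main() {
  std::vector<float> interpreter_buffer = {1.0f};  // reused output buffer
  FakeTensor out{interpreter_buffer.data(), sizeof(float)};

  FakeTensor aliased = out;      // old behavior: packet shares the buffer
  CopiedTensor copied(out);      // new behavior: packet owns a copy

  interpreter_buffer[0] = 2.0f;  // the next Invoke() overwrites the buffer

  std::cout << "aliased sees " << aliased.data[0]          // 2: corrupted
            << ", copied sees " << copied.tensor.data[0]   // 1: stable
            << std::endl;
  return 0;
}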

mediapipe/calculators/tflite/BUILD

@@ -196,6 +196,12 @@ cc_test(
     ],
 )
 
+cc_library(
+    name = "util",
+    hdrs = ["util.h"],
+    alwayslink = 1,
+)
+
 selects.config_setting_group(
     name = "gpu_inference_disabled",
     match_any = [
@@ -222,6 +228,7 @@ cc_library(
     }),
     visibility = ["//visibility:public"],
     deps = [
+        ":util",
         ":tflite_inference_calculator_cc_proto",
         "@com_google_absl//absl/memory",
         "//mediapipe/framework:calculator_framework",
@@ -288,6 +295,7 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         "//mediapipe/util/tflite:config",
+        ":util",
         ":tflite_converter_calculator_cc_proto",
         "//mediapipe/util:resource_util",
         "//mediapipe/framework:calculator_framework",
@@ -407,6 +415,7 @@ cc_library(
     }),
     visibility = ["//visibility:public"],
     deps = [
+        ":util",
        "//mediapipe/util/tflite:config",
        ":tflite_tensors_to_detections_calculator_cc_proto",
        "//mediapipe/framework/formats:detection_cc_proto",

mediapipe/calculators/tflite/tflite_converter_calculator.cc

@@ -16,6 +16,7 @@
 #include <vector>
 
 #include "mediapipe/calculators/tflite/tflite_converter_calculator.pb.h"
+#include "mediapipe/calculators/tflite/util.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/formats/image_frame.h"
 #include "mediapipe/framework/formats/matrix.h"
@@ -207,7 +208,7 @@ bool ShouldUseGpu(CC* cc) {
 #endif  // MEDIAPIPE_DISABLE_GPU
 
   if (cc->Outputs().HasTag(kTensorsTag)) {
-    cc->Outputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
+    cc->Outputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensorContainer>>();
   }
   if (cc->Outputs().HasTag(kTensorsGpuTag)) {
     cc->Outputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
@@ -368,8 +369,9 @@ bool ShouldUseGpu(CC* cc) {
       }
     }
 
-    auto output_tensors = absl::make_unique<std::vector<TfLiteTensor>>();
-    output_tensors->emplace_back(*tensor);
+    auto output_tensors = absl::make_unique<std::vector<TfLiteTensorContainer>>();
+    TfLiteTensorContainer tensor_out(*tensor);
+    output_tensors->emplace_back(tensor_out);
     cc->Outputs()
         .Tag(kTensorsTag)
         .Add(output_tensors.release(), cc->InputTimestamp());
@@ -400,8 +402,9 @@ bool ShouldUseGpu(CC* cc) {
     MP_RETURN_IF_ERROR(CopyMatrixToTensor(matrix, tensor_ptr));
 
-    auto output_tensors = absl::make_unique<std::vector<TfLiteTensor>>();
-    output_tensors->emplace_back(*tensor);
+    auto output_tensors = absl::make_unique<std::vector<TfLiteTensorContainer>>();
+    TfLiteTensorContainer tensor_out(*tensor);
+    output_tensors->emplace_back(tensor_out);
     cc->Outputs()
         .Tag(kTensorsTag)
         .Add(output_tensors.release(), cc->InputTimestamp());
@@ -439,6 +442,8 @@ bool ShouldUseGpu(CC* cc) {
       [this, &output_tensors]() -> ::mediapipe::Status {
         output_tensors->resize(1);
         {
+          // Thuan (2020-04-14): fix for unstable output video.
+          // TODO: Check that the tensor buffer does not reference internal GPU memory.
           GpuTensor& tensor = output_tensors->at(0);
           MP_RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer<float>(
               gpu_data_out_->elements, &tensor));

mediapipe/calculators/tflite/tflite_converter_calculator_test.cc

@@ -30,6 +30,7 @@
 #include "mediapipe/framework/port/status_matchers.h"  // NOLINT
 #include "mediapipe/framework/tool/validate_type.h"
 #include "tensorflow/lite/interpreter.h"
+#include "mediapipe/calculators/tflite/util.h"
 
 namespace mediapipe {
 namespace {
@@ -114,11 +115,11 @@ TEST_F(TfLiteConverterCalculatorTest, RandomMatrixColMajor) {
   EXPECT_EQ(1, output_packets.size());
 
   // Get and process results.
-  const std::vector<TfLiteTensor>& tensor_vec =
-      output_packets[0].Get<std::vector<TfLiteTensor>>();
+  const std::vector<TfLiteTensorContainer>& tensor_vec =
+      output_packets[0].Get<std::vector<TfLiteTensorContainer>>();
   EXPECT_EQ(1, tensor_vec.size());
 
-  const TfLiteTensor* tensor = &tensor_vec[0];
+  const TfLiteTensor* tensor = &(tensor_vec[0].getTensor());
   EXPECT_EQ(kTfLiteFloat32, tensor->type);
 
   // Verify that the data is correct.
@@ -175,11 +176,11 @@ TEST_F(TfLiteConverterCalculatorTest, RandomMatrixRowMajor) {
   EXPECT_EQ(1, output_packets.size());
 
   // Get and process results.
-  const std::vector<TfLiteTensor>& tensor_vec =
-      output_packets[0].Get<std::vector<TfLiteTensor>>();
+  const std::vector<TfLiteTensorContainer>& tensor_vec =
+      output_packets[0].Get<std::vector<TfLiteTensorContainer>>();
   EXPECT_EQ(1, tensor_vec.size());
 
-  const TfLiteTensor* tensor = &tensor_vec[0];
+  const TfLiteTensor* tensor = &(tensor_vec[0].getTensor());
   EXPECT_EQ(kTfLiteFloat32, tensor->type);
 
   // Verify that the data is correct.

mediapipe/calculators/tflite/tflite_inference_calculator.cc

@@ -19,6 +19,7 @@
 #include "absl/memory/memory.h"
 #include "mediapipe/calculators/tflite/tflite_inference_calculator.pb.h"
+#include "mediapipe/calculators/tflite/util.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/port/ret_check.h"
 #include "mediapipe/util/tflite/config.h"
@@ -232,15 +233,15 @@ class TfLiteInferenceCalculator : public CalculatorBase {
   ::mediapipe::Status LoadDelegate(CalculatorContext* cc);
   ::mediapipe::Status InitTFLiteGPURunner(CalculatorContext* cc);
   ::mediapipe::Status ProcessInputsCpu(
-      CalculatorContext* cc, std::vector<TfLiteTensor>* output_tensors_cpu);
+      CalculatorContext* cc, std::vector<TfLiteTensorContainer>* output_tensors_cpu);
   ::mediapipe::Status ProcessOutputsCpu(
       CalculatorContext* cc,
-      std::unique_ptr<std::vector<TfLiteTensor>> output_tensors_cpu);
+      std::unique_ptr<std::vector<TfLiteTensorContainer>> output_tensors_cpu);
   ::mediapipe::Status ProcessInputsGpu(
       CalculatorContext* cc, std::vector<GpuTensor>* output_tensors_gpu);
   ::mediapipe::Status ProcessOutputsGpu(
       CalculatorContext* cc,
-      std::unique_ptr<std::vector<TfLiteTensor>> output_tensors_cpu,
+      std::unique_ptr<std::vector<TfLiteTensorContainer>> output_tensors_cpu,
       std::unique_ptr<std::vector<GpuTensor>> output_tensors_gpu);
 
   ::mediapipe::Status RunInContextIfNeeded(
@@ -319,9 +320,9 @@ bool ShouldUseGpu(CC* cc) {
       << "Either model as side packet or model path in options is required.";
 
   if (cc->Inputs().HasTag(kTensorsTag))
-    cc->Inputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
+    cc->Inputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensorContainer>>();
   if (cc->Outputs().HasTag(kTensorsTag))
-    cc->Outputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
+    cc->Outputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensorContainer>>();
 
   if (cc->Inputs().HasTag(kTensorsGpuTag))
     cc->Inputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
@@ -413,7 +414,7 @@ bool ShouldUseGpu(CC* cc) {
   return RunInContextIfNeeded([this, cc]() -> ::mediapipe::Status {
     // 0. Declare outputs
     auto output_tensors_gpu = absl::make_unique<std::vector<GpuTensor>>();
-    auto output_tensors_cpu = absl::make_unique<std::vector<TfLiteTensor>>();
+    auto output_tensors_cpu = absl::make_unique<std::vector<TfLiteTensorContainer>>();
 
     // 1. Receive pre-processed tensor inputs.
     if (gpu_input_) {
@@ -487,16 +488,16 @@ bool ShouldUseGpu(CC* cc) {
 // Calculator Auxiliary Section
 
 ::mediapipe::Status TfLiteInferenceCalculator::ProcessInputsCpu(
-    CalculatorContext* cc, std::vector<TfLiteTensor>* output_tensors_cpu) {
+    CalculatorContext* cc, std::vector<TfLiteTensorContainer>* output_tensors_cpu) {
   if (cc->Inputs().Tag(kTensorsTag).IsEmpty()) {
     return ::mediapipe::OkStatus();
   }
   // Read CPU input into tensors.
   const auto& input_tensors =
-      cc->Inputs().Tag(kTensorsTag).Get<std::vector<TfLiteTensor>>();
+      cc->Inputs().Tag(kTensorsTag).Get<std::vector<TfLiteTensorContainer>>();
   RET_CHECK_GT(input_tensors.size(), 0);
   for (int i = 0; i < input_tensors.size(); ++i) {
-    const TfLiteTensor* input_tensor = &input_tensors[i];
+    const TfLiteTensor* input_tensor = &(input_tensors[i].getTensor());
     RET_CHECK(input_tensor->data.raw);
     if (use_quantized_tensors_) {
       const uint8* input_tensor_buffer = input_tensor->data.uint8;
@@ -588,12 +589,16 @@ bool ShouldUseGpu(CC* cc) {
 ::mediapipe::Status TfLiteInferenceCalculator::ProcessOutputsCpu(
     CalculatorContext* cc,
-    std::unique_ptr<std::vector<TfLiteTensor>> output_tensors_cpu) {
+    std::unique_ptr<std::vector<TfLiteTensorContainer>> output_tensors_cpu) {
   // Output result tensors (CPU).
   const auto& tensor_indexes = interpreter_->outputs();
   for (int i = 0; i < tensor_indexes.size(); ++i) {
     TfLiteTensor* tensor = interpreter_->tensor(tensor_indexes[i]);
-    output_tensors_cpu->emplace_back(*tensor);
+    // Thuan (2020-04-14): fix for unstable output video. TfLiteTensorContainer
+    // allocates its own memory for the tensor data instead of sharing the
+    // interpreter's buffer.
+    TfLiteTensorContainer tensor_out(*tensor);
+    VLOG(2) << "INFERENCE interpreter_=" << interpreter_.get()
+            << "; InputTimestamp=" << cc->InputTimestamp()
+            << "; output tensor data address=" << tensor->data.f;
+    output_tensors_cpu->emplace_back(tensor_out);
   }
   cc->Outputs()
       .Tag(kTensorsTag)
@@ -604,7 +609,7 @@ bool ShouldUseGpu(CC* cc) {
 ::mediapipe::Status TfLiteInferenceCalculator::ProcessOutputsGpu(
     CalculatorContext* cc,
-    std::unique_ptr<std::vector<TfLiteTensor>> output_tensors_cpu,
+    std::unique_ptr<std::vector<TfLiteTensorContainer>> output_tensors_cpu,
     std::unique_ptr<std::vector<GpuTensor>> output_tensors_gpu) {
   if (use_advanced_gpu_api_) {
 #if MEDIAPIPE_TFLITE_GL_INFERENCE
@@ -621,7 +626,8 @@ bool ShouldUseGpu(CC* cc) {
         std::vector<float> gpu_data(tensor->bytes / sizeof(float));
         MP_RETURN_IF_ERROR(gpu_data_out_[i]->buffer.Read(
             absl::MakeSpan(tensor->data.f, tensor->bytes)));
-        output_tensors_cpu->emplace_back(*tensor);
+        TfLiteTensorContainer tensor_out(*tensor);
+        output_tensors_cpu->emplace_back(tensor_out);
       }
       // Output result tensors (CPU).
       cc->Outputs()

mediapipe/calculators/tflite/tflite_inference_calculator_test.cc

@@ -85,11 +85,11 @@ void DoSmokeTest(const std::string& graph_proto) {
   ASSERT_EQ(1, output_packets.size());
 
   // Get and process results.
-  const std::vector<TfLiteTensor>& result_vec =
-      output_packets[0].Get<std::vector<TfLiteTensor>>();
+  const std::vector<TfLiteTensorContainer>& result_vec =
+      output_packets[0].Get<std::vector<TfLiteTensorContainer>>();
   ASSERT_EQ(1, result_vec.size());
 
-  const TfLiteTensor* result = &result_vec[0];
+  const TfLiteTensor* result = &(result_vec[0].getTensor());
   float* result_buffer = result->data.f;
   ASSERT_NE(result_buffer, nullptr);
   for (int i = 0; i < width * height * channels - 1; i++) {

mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc

@@ -18,6 +18,7 @@
 #include "absl/strings/str_format.h"
 #include "absl/types/span.h"
 #include "mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.pb.h"
+#include "mediapipe/calculators/tflite/util.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/deps/file_path.h"
 #include "mediapipe/framework/formats/detection.pb.h"
@@ -46,6 +47,9 @@
 #include "tensorflow/lite/delegates/gpu/metal_delegate.h"
 #endif  // MEDIAPIPE_TFLITE_METAL_INFERENCE
 
+// Thuan (2020-04-14): fix for unstable output video.
+// TODO: If the detection has a mask or other array/pointer data, consider
+// not sharing a reference to it in the output.
 namespace {
 constexpr int kNumInputTensorsWithAnchors = 3;
 constexpr int kNumCoordsPerBox = 4;
@@ -197,7 +201,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
   bool use_gpu = false;
 
   if (cc->Inputs().HasTag(kTensorsTag)) {
-    cc->Inputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
+    cc->Inputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensorContainer>>();
   }
 
   if (cc->Inputs().HasTag(kTensorsGpuTag)) {
@@ -278,14 +282,15 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
 ::mediapipe::Status TfLiteTensorsToDetectionsCalculator::ProcessCPU(
     CalculatorContext* cc, std::vector<Detection>* output_detections) {
   const auto& input_tensors =
-      cc->Inputs().Tag(kTensorsTag).Get<std::vector<TfLiteTensor>>();
+      cc->Inputs().Tag(kTensorsTag).Get<std::vector<TfLiteTensorContainer>>();
 
   if (input_tensors.size() == 2 ||
       input_tensors.size() == kNumInputTensorsWithAnchors) {
     // Postprocessing on CPU for model without postprocessing op. E.g. output
     // raw score tensor and box tensor. Anchor decoding will be handled below.
-    const TfLiteTensor* raw_box_tensor = &input_tensors[0];
-    const TfLiteTensor* raw_score_tensor = &input_tensors[1];
+    // Thuan (2020-04-14): fix for unstable output video.
+    const TfLiteTensor* raw_box_tensor = &(input_tensors[0].getTensor());
+    const TfLiteTensor* raw_score_tensor = &(input_tensors[1].getTensor());
 
     // TODO: Add flexible input tensor size handling.
     CHECK_EQ(raw_box_tensor->dims->size, 3);
@@ -299,10 +304,16 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
     const float* raw_boxes = raw_box_tensor->data.f;
     const float* raw_scores = raw_score_tensor->data.f;
+    VLOG(2) << "TENSOR TO DETECTION; InputTimestamp=" << cc->InputTimestamp()
+            << "; num_boxes_=" << num_boxes_
+            << "; input tensor boxes data address=" << raw_boxes
+            << "; input tensor scores data address=" << raw_scores;
 
     // TODO: Support other options to load anchors.
     if (!anchors_init_) {
       if (input_tensors.size() == kNumInputTensorsWithAnchors) {
-        const TfLiteTensor* anchor_tensor = &input_tensors[2];
+        VLOG(1) << "Execute the anchor TENSOR";
+        const TfLiteTensor* anchor_tensor = &(input_tensors[2].getTensor());
         CHECK_EQ(anchor_tensor->dims->size, 2);
         CHECK_EQ(anchor_tensor->dims->data[0], num_boxes_);
         CHECK_EQ(anchor_tensor->dims->data[1], kNumCoordsPerBox);
@@ -360,10 +371,10 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
     // non-maximum suppression) within the model.
     RET_CHECK_EQ(input_tensors.size(), 4);
 
-    const TfLiteTensor* detection_boxes_tensor = &input_tensors[0];
-    const TfLiteTensor* detection_classes_tensor = &input_tensors[1];
-    const TfLiteTensor* detection_scores_tensor = &input_tensors[2];
-    const TfLiteTensor* num_boxes_tensor = &input_tensors[3];
+    const TfLiteTensor* detection_boxes_tensor = &(input_tensors[0].getTensor());
+    const TfLiteTensor* detection_classes_tensor = &(input_tensors[1].getTensor());
+    const TfLiteTensor* detection_scores_tensor = &(input_tensors[2].getTensor());
+    const TfLiteTensor* num_boxes_tensor = &(input_tensors[3].getTensor());
     RET_CHECK_EQ(num_boxes_tensor->dims->size, 1);
     RET_CHECK_EQ(num_boxes_tensor->dims->data[0], 1);
     const float* num_boxes = num_boxes_tensor->data.f;

mediapipe/calculators/tflite/util.h (new file)

@@ -0,0 +1,88 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_TFLITE_UTIL_H_
#define MEDIAPIPE_CALCULATORS_TFLITE_UTIL_H_

#include <cstring>
#include <memory>

#include "absl/memory/memory.h"
#include "tensorflow/lite/interpreter.h"

#define RET_CHECK_CALL(call)                                \
  do {                                                      \
    const auto status = (call);                             \
    if (ABSL_PREDICT_FALSE(!status.ok()))                   \
      return ::mediapipe::InternalError(status.message());  \
  } while (0);

namespace mediapipe {

// Owns a deep copy of a TfLiteTensor, including its dims and data buffer, so
// that downstream packets do not alias (and get overwritten through) the
// interpreter's internal buffers.
class TfLiteTensorContainer {
 private:
  TfLiteTensor tensor_;
  // dims_ is created with TfLiteIntArrayCreate(), so it must be released with
  // TfLiteIntArrayFree() rather than delete.
  std::unique_ptr<TfLiteIntArray, decltype(&TfLiteIntArrayFree)> dims_{
      nullptr, TfLiteIntArrayFree};
  std::unique_ptr<char[]> data_;

  // Frees the internal copies and clears the pointers held by tensor_.
  void FreeTensor() {
    tensor_.dims = nullptr;
    tensor_.data.raw = nullptr;
    dims_.reset();
    data_.reset();
  }

  // Deep-copies dims and data from the source tensor into internal storage.
  void CopyTensor(const TfLiteTensor& tensor) {
    // Free any previously held copy before copying the new data.
    FreeTensor();
    // Copy dims and data from the source into the internal members.
    dims_.reset(TfLiteIntArrayCreate(tensor.dims->size));
    memcpy(dims_->data, tensor.dims->data, sizeof(int) * tensor.dims->size);
    data_ = absl::make_unique<char[]>(tensor.bytes);
    memcpy(data_.get(), tensor.data.raw, tensor.bytes);
    memcpy(&tensor_, &tensor, sizeof(TfLiteTensor));
    tensor_.dims = dims_.get();
    tensor_.data.raw = data_.get();
  }

 public:
  TfLiteTensorContainer(const TfLiteTensor& tensor) { CopyTensor(tensor); }

  // Copy constructor.
  TfLiteTensorContainer(const TfLiteTensorContainer& tensor_ctn) {
    CopyTensor(tensor_ctn.getTensor());
  }

  // Copy assignment.
  TfLiteTensorContainer& operator=(const TfLiteTensorContainer& tensor_ctn) {
    if (this != &tensor_ctn) CopyTensor(tensor_ctn.getTensor());
    return *this;
  }

  ~TfLiteTensorContainer() { FreeTensor(); }

  // Returns the owned tensor; its dims and data point at the internal copies.
  const TfLiteTensor& getTensor() const { return tensor_; }
};

}  // namespace mediapipe

#endif  // MEDIAPIPE_CALCULATORS_TFLITE_UTIL_H_
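
A usage sketch of the container above (EmitOutputTensors is a hypothetical helper written for illustration, not part of this commit): a calculator copies every interpreter output into a TfLiteTensorContainer before emitting the packet, and a consumer reads the copies back through getTensor().

#include <vector>

#include "absl/memory/memory.h"
#include "mediapipe/calculators/tflite/util.h"
#include "mediapipe/framework/calculator_framework.h"
#include "tensorflow/lite/interpreter.h"

namespace mediapipe {

// Hypothetical helper: copy each interpreter output into a
// TfLiteTensorContainer so the emitted packet owns its own data.
::mediapipe::Status EmitOutputTensors(CalculatorContext* cc,
                                      tflite::Interpreter* interpreter) {
  auto output_tensors = absl::make_unique<std::vector<TfLiteTensorContainer>>();
  for (int index : interpreter->outputs()) {
    // Deep copy: the packet no longer aliases the interpreter's buffer, so a
    // later Invoke() cannot overwrite data a downstream calculator still needs.
    output_tensors->emplace_back(*interpreter->tensor(index));
  }
  cc->Outputs().Tag("TENSORS").Add(output_tensors.release(),
                                   cc->InputTimestamp());
  return ::mediapipe::OkStatus();
}

// A downstream calculator reads the copies through getTensor():
//   const auto& tensors =
//       cc->Inputs().Tag("TENSORS").Get<std::vector<TfLiteTensorContainer>>();
//   const TfLiteTensor* boxes = &(tensors[0].getTensor());

}  // namespace mediapipe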