Fix bug: override output tensor data buffer in tflite_converter_calculator and tflite_inference_calculator and tflite_tensors_to_detections_calculator. This bug cause output wrong object detection in muti-threading. It make output video unstable (different output video with same input)
This commit is contained in:
parent
1db91b550a
commit
328905ec6b
|
@ -196,6 +196,12 @@ cc_test(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "util",
|
||||
hdrs = ["util.h"],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
selects.config_setting_group(
|
||||
name = "gpu_inference_disabled",
|
||||
match_any = [
|
||||
|
@ -222,6 +228,7 @@ cc_library(
|
|||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":util",
|
||||
":tflite_inference_calculator_cc_proto",
|
||||
"@com_google_absl//absl/memory",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -288,6 +295,7 @@ cc_library(
|
|||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/util/tflite:config",
|
||||
":util",
|
||||
":tflite_converter_calculator_cc_proto",
|
||||
"//mediapipe/util:resource_util",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
|
@ -407,6 +415,7 @@ cc_library(
|
|||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":util",
|
||||
"//mediapipe/util/tflite:config",
|
||||
":tflite_tensors_to_detections_calculator_cc_proto",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "mediapipe/calculators/tflite/tflite_converter_calculator.pb.h"
|
||||
#include "mediapipe/calculators/tflite/util.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/image_frame.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
|
@ -207,7 +208,7 @@ bool ShouldUseGpu(CC* cc) {
|
|||
#endif // MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
if (cc->Outputs().HasTag(kTensorsTag)) {
|
||||
cc->Outputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
|
||||
cc->Outputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensorContainer>>();
|
||||
}
|
||||
if (cc->Outputs().HasTag(kTensorsGpuTag)) {
|
||||
cc->Outputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
|
||||
|
@ -368,8 +369,9 @@ bool ShouldUseGpu(CC* cc) {
|
|||
}
|
||||
}
|
||||
|
||||
auto output_tensors = absl::make_unique<std::vector<TfLiteTensor>>();
|
||||
output_tensors->emplace_back(*tensor);
|
||||
auto output_tensors = absl::make_unique<std::vector<TfLiteTensorContainer>>();
|
||||
TfLiteTensorContainer tensor_out(*tensor);
|
||||
output_tensors->emplace_back(tensor_out);
|
||||
cc->Outputs()
|
||||
.Tag(kTensorsTag)
|
||||
.Add(output_tensors.release(), cc->InputTimestamp());
|
||||
|
@ -400,8 +402,9 @@ bool ShouldUseGpu(CC* cc) {
|
|||
|
||||
MP_RETURN_IF_ERROR(CopyMatrixToTensor(matrix, tensor_ptr));
|
||||
|
||||
auto output_tensors = absl::make_unique<std::vector<TfLiteTensor>>();
|
||||
output_tensors->emplace_back(*tensor);
|
||||
auto output_tensors = absl::make_unique<std::vector<TfLiteTensorContainer>>();
|
||||
TfLiteTensorContainer tensor_out(*tensor);
|
||||
output_tensors->emplace_back(tensor_out);
|
||||
cc->Outputs()
|
||||
.Tag(kTensorsTag)
|
||||
.Add(output_tensors.release(), cc->InputTimestamp());
|
||||
|
@ -439,6 +442,8 @@ bool ShouldUseGpu(CC* cc) {
|
|||
[this, &output_tensors]() -> ::mediapipe::Status {
|
||||
output_tensors->resize(1);
|
||||
{
|
||||
// Thuan (2020-04-14: Fix bug output video not stable)
|
||||
// - TODO Check buffer of tensor is not reference internal memory in GPU
|
||||
GpuTensor& tensor = output_tensors->at(0);
|
||||
MP_RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer<float>(
|
||||
gpu_data_out_->elements, &tensor));
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
|
||||
#include "mediapipe/framework/tool/validate_type.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "mediapipe/calculators/tflite/util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
@ -114,11 +115,11 @@ TEST_F(TfLiteConverterCalculatorTest, RandomMatrixColMajor) {
|
|||
EXPECT_EQ(1, output_packets.size());
|
||||
|
||||
// Get and process results.
|
||||
const std::vector<TfLiteTensor>& tensor_vec =
|
||||
output_packets[0].Get<std::vector<TfLiteTensor>>();
|
||||
const std::vector<TfLiteTensorContainer>& tensor_vec =
|
||||
output_packets[0].Get<std::vector<TfLiteTensorContainer>>();
|
||||
EXPECT_EQ(1, tensor_vec.size());
|
||||
|
||||
const TfLiteTensor* tensor = &tensor_vec[0];
|
||||
const TfLiteTensorr* tensor = &(tensor_vec[0].getTensor());
|
||||
EXPECT_EQ(kTfLiteFloat32, tensor->type);
|
||||
|
||||
// Verify that the data is correct.
|
||||
|
@ -175,11 +176,11 @@ TEST_F(TfLiteConverterCalculatorTest, RandomMatrixRowMajor) {
|
|||
EXPECT_EQ(1, output_packets.size());
|
||||
|
||||
// Get and process results.
|
||||
const std::vector<TfLiteTensor>& tensor_vec =
|
||||
output_packets[0].Get<std::vector<TfLiteTensor>>();
|
||||
const std::vector<TfLiteTensorContainer>& tensor_vec =
|
||||
output_packets[0].Get<std::vector<TfLiteTensorContainer>>();
|
||||
EXPECT_EQ(1, tensor_vec.size());
|
||||
|
||||
const TfLiteTensor* tensor = &tensor_vec[0];
|
||||
const TfLiteTensor* tensor = &(tensor_vec[0].getTensor());
|
||||
EXPECT_EQ(kTfLiteFloat32, tensor->type);
|
||||
|
||||
// Verify that the data is correct.
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "mediapipe/calculators/tflite/tflite_inference_calculator.pb.h"
|
||||
#include "mediapipe/calculators/tflite/util.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/util/tflite/config.h"
|
||||
|
@ -232,15 +233,15 @@ class TfLiteInferenceCalculator : public CalculatorBase {
|
|||
::mediapipe::Status LoadDelegate(CalculatorContext* cc);
|
||||
::mediapipe::Status InitTFLiteGPURunner(CalculatorContext* cc);
|
||||
::mediapipe::Status ProcessInputsCpu(
|
||||
CalculatorContext* cc, std::vector<TfLiteTensor>* output_tensors_cpu);
|
||||
CalculatorContext* cc, std::vector<TfLiteTensorContainer>* output_tensors_cpu);
|
||||
::mediapipe::Status ProcessOutputsCpu(
|
||||
CalculatorContext* cc,
|
||||
std::unique_ptr<std::vector<TfLiteTensor>> output_tensors_cpu);
|
||||
std::unique_ptr<std::vector<TfLiteTensorContainer>> output_tensors_cpu);
|
||||
::mediapipe::Status ProcessInputsGpu(
|
||||
CalculatorContext* cc, std::vector<GpuTensor>* output_tensors_gpu);
|
||||
::mediapipe::Status ProcessOutputsGpu(
|
||||
CalculatorContext* cc,
|
||||
std::unique_ptr<std::vector<TfLiteTensor>> output_tensors_cpu,
|
||||
std::unique_ptr<std::vector<TfLiteTensorContainer>> output_tensors_cpu,
|
||||
std::unique_ptr<std::vector<GpuTensor>> output_tensors_gpu);
|
||||
|
||||
::mediapipe::Status RunInContextIfNeeded(
|
||||
|
@ -319,9 +320,9 @@ bool ShouldUseGpu(CC* cc) {
|
|||
<< "Either model as side packet or model path in options is required.";
|
||||
|
||||
if (cc->Inputs().HasTag(kTensorsTag))
|
||||
cc->Inputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
|
||||
cc->Inputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensorContainer>>();
|
||||
if (cc->Outputs().HasTag(kTensorsTag))
|
||||
cc->Outputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
|
||||
cc->Outputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensorContainer>>();
|
||||
|
||||
if (cc->Inputs().HasTag(kTensorsGpuTag))
|
||||
cc->Inputs().Tag(kTensorsGpuTag).Set<std::vector<GpuTensor>>();
|
||||
|
@ -413,7 +414,7 @@ bool ShouldUseGpu(CC* cc) {
|
|||
return RunInContextIfNeeded([this, cc]() -> ::mediapipe::Status {
|
||||
// 0. Declare outputs
|
||||
auto output_tensors_gpu = absl::make_unique<std::vector<GpuTensor>>();
|
||||
auto output_tensors_cpu = absl::make_unique<std::vector<TfLiteTensor>>();
|
||||
auto output_tensors_cpu = absl::make_unique<std::vector<TfLiteTensorContainer>>();
|
||||
|
||||
// 1. Receive pre-processed tensor inputs.
|
||||
if (gpu_input_) {
|
||||
|
@ -487,16 +488,16 @@ bool ShouldUseGpu(CC* cc) {
|
|||
// Calculator Auxiliary Section
|
||||
|
||||
::mediapipe::Status TfLiteInferenceCalculator::ProcessInputsCpu(
|
||||
CalculatorContext* cc, std::vector<TfLiteTensor>* output_tensors_cpu) {
|
||||
CalculatorContext* cc, std::vector<TfLiteTensorContainer>* output_tensors_cpu) {
|
||||
if (cc->Inputs().Tag(kTensorsTag).IsEmpty()) {
|
||||
return ::mediapipe::OkStatus();
|
||||
}
|
||||
// Read CPU input into tensors.
|
||||
const auto& input_tensors =
|
||||
cc->Inputs().Tag(kTensorsTag).Get<std::vector<TfLiteTensor>>();
|
||||
cc->Inputs().Tag(kTensorsTag).Get<std::vector<TfLiteTensorContainer>>();
|
||||
RET_CHECK_GT(input_tensors.size(), 0);
|
||||
for (int i = 0; i < input_tensors.size(); ++i) {
|
||||
const TfLiteTensor* input_tensor = &input_tensors[i];
|
||||
const TfLiteTensor* input_tensor = &(input_tensors[i].getTensor());
|
||||
RET_CHECK(input_tensor->data.raw);
|
||||
if (use_quantized_tensors_) {
|
||||
const uint8* input_tensor_buffer = input_tensor->data.uint8;
|
||||
|
@ -588,12 +589,16 @@ bool ShouldUseGpu(CC* cc) {
|
|||
|
||||
::mediapipe::Status TfLiteInferenceCalculator::ProcessOutputsCpu(
|
||||
CalculatorContext* cc,
|
||||
std::unique_ptr<std::vector<TfLiteTensor>> output_tensors_cpu) {
|
||||
std::unique_ptr<std::vector<TfLiteTensorContainer>> output_tensors_cpu) {
|
||||
// Output result tensors (CPU).
|
||||
const auto& tensor_indexes = interpreter_->outputs();
|
||||
for (int i = 0; i < tensor_indexes.size(); ++i) {
|
||||
TfLiteTensor* tensor = interpreter_->tensor(tensor_indexes[i]);
|
||||
output_tensors_cpu->emplace_back(*tensor);
|
||||
// Thuan (2020-04-14: Fix bug output video not stable): Using TfLiteTensorContainer for make new memory for data in tensor
|
||||
TfLiteTensorContainer tensor_out(*tensor);
|
||||
VLOG(2) << "INFERENCE interpreter_=" << interpreter_.get() << ";InputTimestamp=" << cc->InputTimestamp()
|
||||
<< " has output tensor data address=" << tensor->data.f ;
|
||||
output_tensors_cpu->emplace_back(tensor_out);
|
||||
}
|
||||
cc->Outputs()
|
||||
.Tag(kTensorsTag)
|
||||
|
@ -604,7 +609,7 @@ bool ShouldUseGpu(CC* cc) {
|
|||
|
||||
::mediapipe::Status TfLiteInferenceCalculator::ProcessOutputsGpu(
|
||||
CalculatorContext* cc,
|
||||
std::unique_ptr<std::vector<TfLiteTensor>> output_tensors_cpu,
|
||||
std::unique_ptr<std::vector<TfLiteTensorContainer>> output_tensors_cpu,
|
||||
std::unique_ptr<std::vector<GpuTensor>> output_tensors_gpu) {
|
||||
if (use_advanced_gpu_api_) {
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
|
@ -621,7 +626,8 @@ bool ShouldUseGpu(CC* cc) {
|
|||
std::vector<float> gpu_data(tensor->bytes / sizeof(float));
|
||||
MP_RETURN_IF_ERROR(gpu_data_out_[i]->buffer.Read(
|
||||
absl::MakeSpan(tensor->data.f, tensor->bytes)));
|
||||
output_tensors_cpu->emplace_back(*tensor);
|
||||
TfLiteTensorContainer tensor_out(*tensor);
|
||||
output_tensors_cpu->emplace_back(tensor_out);
|
||||
}
|
||||
// Output result tensors (CPU).
|
||||
cc->Outputs()
|
||||
|
|
|
@ -85,11 +85,11 @@ void DoSmokeTest(const std::string& graph_proto) {
|
|||
ASSERT_EQ(1, output_packets.size());
|
||||
|
||||
// Get and process results.
|
||||
const std::vector<TfLiteTensor>& result_vec =
|
||||
output_packets[0].Get<std::vector<TfLiteTensor>>();
|
||||
const std::vector<TfLiteTensorContainer> & result_vec =
|
||||
output_packets[0].Get<std::vector<TfLiteTensorContainer>>();
|
||||
ASSERT_EQ(1, result_vec.size());
|
||||
|
||||
const TfLiteTensor* result = &result_vec[0];
|
||||
const TfLiteTensor* result = &(result_vec[0].getTensor());
|
||||
float* result_buffer = result->data.f;
|
||||
ASSERT_NE(result_buffer, nullptr);
|
||||
for (int i = 0; i < width * height * channels - 1; i++) {
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "absl/strings/str_format.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.pb.h"
|
||||
#include "mediapipe/calculators/tflite/util.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
|
@ -46,6 +47,9 @@
|
|||
#include "tensorflow/lite/delegates/gpu/metal_delegate.h"
|
||||
#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE
|
||||
|
||||
|
||||
// Thuan (2020-04-14: Fix bug output video not stable)
|
||||
//TODO: If the detection has mask and other data is array or pointer, then we consider not share reference as output it
|
||||
namespace {
|
||||
constexpr int kNumInputTensorsWithAnchors = 3;
|
||||
constexpr int kNumCoordsPerBox = 4;
|
||||
|
@ -197,7 +201,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
|
|||
bool use_gpu = false;
|
||||
|
||||
if (cc->Inputs().HasTag(kTensorsTag)) {
|
||||
cc->Inputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensor>>();
|
||||
cc->Inputs().Tag(kTensorsTag).Set<std::vector<TfLiteTensorContainer>>();
|
||||
}
|
||||
|
||||
if (cc->Inputs().HasTag(kTensorsGpuTag)) {
|
||||
|
@ -278,14 +282,15 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
|
|||
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::ProcessCPU(
|
||||
CalculatorContext* cc, std::vector<Detection>* output_detections) {
|
||||
const auto& input_tensors =
|
||||
cc->Inputs().Tag(kTensorsTag).Get<std::vector<TfLiteTensor>>();
|
||||
cc->Inputs().Tag(kTensorsTag).Get<std::vector<TfLiteTensorContainer>>();
|
||||
|
||||
if (input_tensors.size() == 2 ||
|
||||
input_tensors.size() == kNumInputTensorsWithAnchors) {
|
||||
// Postprocessing on CPU for model without postprocessing op. E.g. output
|
||||
// raw score tensor and box tensor. Anchor decoding will be handled below.
|
||||
const TfLiteTensor* raw_box_tensor = &input_tensors[0];
|
||||
const TfLiteTensor* raw_score_tensor = &input_tensors[1];
|
||||
// Thuan (2020-04-14: Fix bug output video not stable)
|
||||
const TfLiteTensor* raw_box_tensor = &(input_tensors[0].getTensor());
|
||||
const TfLiteTensor* raw_score_tensor = &(input_tensors[1].getTensor());
|
||||
|
||||
// TODO: Add flexible input tensor size handling.
|
||||
CHECK_EQ(raw_box_tensor->dims->size, 3);
|
||||
|
@ -299,10 +304,16 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
|
|||
const float* raw_boxes = raw_box_tensor->data.f;
|
||||
const float* raw_scores = raw_score_tensor->data.f;
|
||||
|
||||
VLOG(2) << "TENSOR TO DETECTION;InputTimestamp=" << cc->InputTimestamp() << "num_boxes_=" << num_boxes_
|
||||
<< " has input tensor boxes data address=" << raw_boxes << "; input tensor scores data address=" << raw_scores ;
|
||||
|
||||
|
||||
// TODO: Support other options to load anchors.
|
||||
if (!anchors_init_) {
|
||||
if (input_tensors.size() == kNumInputTensorsWithAnchors) {
|
||||
const TfLiteTensor* anchor_tensor = &input_tensors[2];
|
||||
VLOG(1) << "Execute the anchor TENSOR";
|
||||
|
||||
const TfLiteTensor* anchor_tensor = &(input_tensors[2].getTensor());
|
||||
CHECK_EQ(anchor_tensor->dims->size, 2);
|
||||
CHECK_EQ(anchor_tensor->dims->data[0], num_boxes_);
|
||||
CHECK_EQ(anchor_tensor->dims->data[1], kNumCoordsPerBox);
|
||||
|
@ -360,10 +371,10 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
|
|||
// non-maximum suppression) within the model.
|
||||
RET_CHECK_EQ(input_tensors.size(), 4);
|
||||
|
||||
const TfLiteTensor* detection_boxes_tensor = &input_tensors[0];
|
||||
const TfLiteTensor* detection_classes_tensor = &input_tensors[1];
|
||||
const TfLiteTensor* detection_scores_tensor = &input_tensors[2];
|
||||
const TfLiteTensor* num_boxes_tensor = &input_tensors[3];
|
||||
const TfLiteTensor* detection_boxes_tensor = &(input_tensors[0].getTensor());
|
||||
const TfLiteTensor* detection_classes_tensor = &(input_tensors[1].getTensor());
|
||||
const TfLiteTensor* detection_scores_tensor = &(input_tensors[2].getTensor());
|
||||
const TfLiteTensor* num_boxes_tensor = &(input_tensors[3].getTensor());
|
||||
RET_CHECK_EQ(num_boxes_tensor->dims->size, 1);
|
||||
RET_CHECK_EQ(num_boxes_tensor->dims->data[0], 1);
|
||||
const float* num_boxes = num_boxes_tensor->data.f;
|
||||
|
|
88
mediapipe/calculators/tflite/util.h
Normal file
88
mediapipe/calculators/tflite/util.h
Normal file
|
@ -0,0 +1,88 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef MEDIAPIPE_CALCULATORS_TFLITE_UTIL_H_
|
||||
#define MEDIAPIPE_CALCULATORS_TFLITE_UTIL_H_
|
||||
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
|
||||
#define RET_CHECK_CALL(call) \
|
||||
do { \
|
||||
const auto status = (call); \
|
||||
if (ABSL_PREDICT_FALSE(!status.ok())) \
|
||||
return ::mediapipe::InternalError(status.message()); \
|
||||
} while (0);
|
||||
|
||||
namespace mediapipe {
|
||||
class TfLiteTensorContainer {
|
||||
private:
|
||||
TfLiteTensor tensor_;
|
||||
std::unique_ptr<TfLiteIntArray> dims_;
|
||||
std::unique_ptr<char[]> data_;
|
||||
|
||||
//Free internal memory
|
||||
void FreeTensor() {
|
||||
tensor_.dims = 0;
|
||||
tensor_.data.raw = 0;
|
||||
dims_.reset();
|
||||
data_.reset();
|
||||
}
|
||||
//Copy data from source tensor
|
||||
void CopyTensor(const TfLiteTensor& tensor) {
|
||||
//Free internal memory for copy new data
|
||||
FreeTensor();
|
||||
|
||||
//Copy data from source to internal member
|
||||
dims_.reset(TfLiteIntArrayCreate((tensor.dims)->size));
|
||||
memcpy(dims_->data, (tensor.dims)->data, sizeof(int)*((tensor.dims)->size));
|
||||
data_ = absl::make_unique<char[]>(tensor.bytes);
|
||||
memcpy(data_.get(), tensor.data.raw, tensor.bytes);
|
||||
memcpy(&tensor_, &tensor, sizeof(TfLiteTensor));
|
||||
tensor_.dims = dims_.get();
|
||||
tensor_.data.raw = data_.get();
|
||||
}
|
||||
public:
|
||||
TfLiteTensorContainer(const TfLiteTensor& tensor) {
|
||||
CopyTensor(tensor);
|
||||
}
|
||||
//Copy constructor
|
||||
TfLiteTensorContainer(const TfLiteTensorContainer& tensor_ctn) {
|
||||
CopyTensor(tensor_ctn.getTensor());
|
||||
}
|
||||
|
||||
//Destructor
|
||||
~TfLiteTensorContainer() {
|
||||
FreeTensor();
|
||||
}
|
||||
|
||||
// Get tensor
|
||||
const TfLiteTensor& getTensor() const {
|
||||
return tensor_;
|
||||
}
|
||||
|
||||
//Assign operator
|
||||
TfLiteTensorContainer & operator= ( TfLiteTensorContainer tensor_ctn){
|
||||
CopyTensor(tensor_ctn.getTensor());
|
||||
|
||||
return *this;
|
||||
}
|
||||
TfLiteTensorContainer & operator= ( const TfLiteTensorContainer & tensor_ctn) {
|
||||
CopyTensor(tensor_ctn.getTensor());
|
||||
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#endif // MEDIAPIPE_CALCULATORS_TFLITE_UTIL_H_
|
Loading…
Reference in New Issue
Block a user