Signed-off-by: Pierre Fenoll <pierrefenoll@gmail.com>
This commit is contained in:
Pierre Fenoll 2021-09-06 12:48:35 +02:00
parent 789e61ba11
commit bab48969f6
2 changed files with 179 additions and 0 deletions

View File

@ -134,3 +134,17 @@ cc_test(
"//mediapipe/framework/port:parse_text_proto",
],
)
cc_test(
name = "pytorch_inference_calculator_test",
srcs = ["pytorch_inference_calculator_test.cc"],
deps = [
":pytorch_inference_calculator",
":pytorch_inference_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
],
)

View File

@ -0,0 +1,165 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// #include <memory>
// #include <string>
// #include <vector>
// #include "absl/strings/str_replace.h"
// #include "absl/strings/string_view.h"
#include "mediapipe/calculators/pytorch/pytorch_inference_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
// #include "mediapipe/framework/calculator_runner.h"
// #include "mediapipe/framework/deps/file_path.h"
// #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
// #include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
// #include "mediapipe/framework/tool/validate_type.h"
namespace mediapipe {
using ::tflite::Interpreter;
void DoSmokeTest(const std::string& graph_proto) {
const int width = 8;
const int height = 8;
const int channels = 3;
// Prepare input tensor.
std::unique_ptr<Interpreter> interpreter(new Interpreter);
ASSERT_NE(interpreter, nullptr);
interpreter->AddTensors(1);
interpreter->SetInputs({0});
interpreter->SetOutputs({0});
interpreter->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3},
TfLiteQuantization());
int t = interpreter->inputs()[0];
TfLiteTensor* tensor = interpreter->tensor(t);
interpreter->ResizeInputTensor(t, {width, height, channels});
interpreter->AllocateTensors();
float* tensor_buffer = tensor->data.f;
ASSERT_NE(tensor_buffer, nullptr);
for (int i = 0; i < width * height * channels - 1; i++) {
tensor_buffer[i] = 1;
}
auto input_vec = absl::make_unique<std::vector<TfLiteTensor>>();
input_vec->emplace_back(*tensor);
// Prepare single calculator graph to and wait for packets.
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(graph_proto);
std::vector<Packet> output_packets;
tool::AddVectorSink("tensor_out", &graph_config, &output_packets);
CalculatorGraph graph(graph_config);
MP_ASSERT_OK(graph.StartRun({}));
// Push the tensor into the graph.
MP_ASSERT_OK(graph.AddPacketToInputStream(
"tensor_in", Adopt(input_vec.release()).At(Timestamp(0))));
// Wait until the calculator done processing.
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(1, output_packets.size());
// Get and process results.
const std::vector<TfLiteTensor>& result_vec =
output_packets[0].Get<std::vector<TfLiteTensor>>();
ASSERT_EQ(1, result_vec.size());
const TfLiteTensor* result = &result_vec[0];
float* result_buffer = result->data.f;
ASSERT_NE(result_buffer, nullptr);
for (int i = 0; i < width * height * channels - 1; i++) {
ASSERT_EQ(3, result_buffer[i]);
}
// Fully close graph at end, otherwise calculator+tensors are destroyed
// after calling WaitUntilDone().
MP_ASSERT_OK(graph.CloseInputStream("tensor_in"));
MP_ASSERT_OK(graph.WaitUntilDone());
}
// Tests a simple add model that adds an input tensor to itself.
TEST(TfLiteInferenceCalculatorTest, SmokeTest) {
std::string graph_proto = R"(
input_stream: "tensor_in"
node {
calculator: "TfLiteInferenceCalculator"
input_stream: "TENSORS:tensor_in"
output_stream: "TENSORS:tensor_out"
options {
[mediapipe.TfLiteInferenceCalculatorOptions.ext] {
model_path: "mediapipe/calculators/tflite/testdata/add.bin"
$delegate
}
}
}
)";
DoSmokeTest(
/*graph_proto=*/absl::StrReplaceAll(graph_proto, {{"$delegate", ""}}));
DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll(
graph_proto, {{"$delegate", "delegate { tflite {} }"}}));
DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll(
graph_proto, {{"$delegate", "delegate { xnnpack {} }"}}));
DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll(
graph_proto,
{{"$delegate", "delegate { xnnpack { num_threads: 10 } }"}}));
}
TEST(TfLiteInferenceCalculatorTest, SmokeTest_ModelAsInputSidePacket) {
std::string graph_proto = R"(
input_stream: "tensor_in"
node {
calculator: "ConstantSidePacketCalculator"
output_side_packet: "PACKET:model_path"
options: {
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
packet { string_value: "mediapipe/calculators/tflite/testdata/add.bin" }
}
}
}
node {
calculator: "LocalFileContentsCalculator"
input_side_packet: "FILE_PATH:model_path"
output_side_packet: "CONTENTS:model_blob"
}
node {
calculator: "TfLiteModelCalculator"
input_side_packet: "MODEL_BLOB:model_blob"
output_side_packet: "MODEL:model"
}
node {
calculator: "TfLiteInferenceCalculator"
input_stream: "TENSORS:tensor_in"
output_stream: "TENSORS:tensor_out"
input_side_packet: "MODEL:model"
options {
[mediapipe.TfLiteInferenceCalculatorOptions.ext] {
use_gpu: false
}
}
}
)";
DoSmokeTest(graph_proto);
}
} // namespace mediapipe