wip
Signed-off-by: Pierre Fenoll <pierrefenoll@gmail.com>
This commit is contained in:
parent
789e61ba11
commit
bab48969f6
|
@ -134,3 +134,17 @@ cc_test(
|
|||
"//mediapipe/framework/port:parse_text_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "pytorch_inference_calculator_test",
|
||||
srcs = ["pytorch_inference_calculator_test.cc"],
|
||||
deps = [
|
||||
":pytorch_inference_calculator",
|
||||
":pytorch_inference_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:image_frame",
|
||||
"//mediapipe/framework/formats:image_frame_opencv",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -0,0 +1,165 @@
|
|||
// Copyright 2020 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// #include <memory>
|
||||
// #include <string>
|
||||
// #include <vector>
|
||||
|
||||
// #include "absl/strings/str_replace.h"
|
||||
// #include "absl/strings/string_view.h"
|
||||
#include "mediapipe/calculators/pytorch/pytorch_inference_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
// #include "mediapipe/framework/calculator_runner.h"
|
||||
// #include "mediapipe/framework/deps/file_path.h"
|
||||
// #include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
// #include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h" // NOLINT
|
||||
// #include "mediapipe/framework/tool/validate_type.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
using ::tflite::Interpreter;
|
||||
|
||||
void DoSmokeTest(const std::string& graph_proto) {
|
||||
const int width = 8;
|
||||
const int height = 8;
|
||||
const int channels = 3;
|
||||
|
||||
// Prepare input tensor.
|
||||
std::unique_ptr<Interpreter> interpreter(new Interpreter);
|
||||
ASSERT_NE(interpreter, nullptr);
|
||||
|
||||
interpreter->AddTensors(1);
|
||||
interpreter->SetInputs({0});
|
||||
interpreter->SetOutputs({0});
|
||||
interpreter->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3},
|
||||
TfLiteQuantization());
|
||||
int t = interpreter->inputs()[0];
|
||||
TfLiteTensor* tensor = interpreter->tensor(t);
|
||||
interpreter->ResizeInputTensor(t, {width, height, channels});
|
||||
interpreter->AllocateTensors();
|
||||
|
||||
float* tensor_buffer = tensor->data.f;
|
||||
ASSERT_NE(tensor_buffer, nullptr);
|
||||
for (int i = 0; i < width * height * channels - 1; i++) {
|
||||
tensor_buffer[i] = 1;
|
||||
}
|
||||
|
||||
auto input_vec = absl::make_unique<std::vector<TfLiteTensor>>();
|
||||
input_vec->emplace_back(*tensor);
|
||||
|
||||
// Prepare single calculator graph to and wait for packets.
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(graph_proto);
|
||||
std::vector<Packet> output_packets;
|
||||
tool::AddVectorSink("tensor_out", &graph_config, &output_packets);
|
||||
CalculatorGraph graph(graph_config);
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
|
||||
// Push the tensor into the graph.
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"tensor_in", Adopt(input_vec.release()).At(Timestamp(0))));
|
||||
// Wait until the calculator done processing.
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
ASSERT_EQ(1, output_packets.size());
|
||||
|
||||
// Get and process results.
|
||||
const std::vector<TfLiteTensor>& result_vec =
|
||||
output_packets[0].Get<std::vector<TfLiteTensor>>();
|
||||
ASSERT_EQ(1, result_vec.size());
|
||||
|
||||
const TfLiteTensor* result = &result_vec[0];
|
||||
float* result_buffer = result->data.f;
|
||||
ASSERT_NE(result_buffer, nullptr);
|
||||
for (int i = 0; i < width * height * channels - 1; i++) {
|
||||
ASSERT_EQ(3, result_buffer[i]);
|
||||
}
|
||||
|
||||
// Fully close graph at end, otherwise calculator+tensors are destroyed
|
||||
// after calling WaitUntilDone().
|
||||
MP_ASSERT_OK(graph.CloseInputStream("tensor_in"));
|
||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
|
||||
// Tests a simple add model that adds an input tensor to itself.
|
||||
TEST(TfLiteInferenceCalculatorTest, SmokeTest) {
|
||||
std::string graph_proto = R"(
|
||||
input_stream: "tensor_in"
|
||||
node {
|
||||
calculator: "TfLiteInferenceCalculator"
|
||||
input_stream: "TENSORS:tensor_in"
|
||||
output_stream: "TENSORS:tensor_out"
|
||||
options {
|
||||
[mediapipe.TfLiteInferenceCalculatorOptions.ext] {
|
||||
model_path: "mediapipe/calculators/tflite/testdata/add.bin"
|
||||
$delegate
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
DoSmokeTest(
|
||||
/*graph_proto=*/absl::StrReplaceAll(graph_proto, {{"$delegate", ""}}));
|
||||
DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll(
|
||||
graph_proto, {{"$delegate", "delegate { tflite {} }"}}));
|
||||
DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll(
|
||||
graph_proto, {{"$delegate", "delegate { xnnpack {} }"}}));
|
||||
DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll(
|
||||
graph_proto,
|
||||
{{"$delegate", "delegate { xnnpack { num_threads: 10 } }"}}));
|
||||
}
|
||||
|
||||
TEST(TfLiteInferenceCalculatorTest, SmokeTest_ModelAsInputSidePacket) {
|
||||
std::string graph_proto = R"(
|
||||
input_stream: "tensor_in"
|
||||
|
||||
node {
|
||||
calculator: "ConstantSidePacketCalculator"
|
||||
output_side_packet: "PACKET:model_path"
|
||||
options: {
|
||||
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
|
||||
packet { string_value: "mediapipe/calculators/tflite/testdata/add.bin" }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
node {
|
||||
calculator: "LocalFileContentsCalculator"
|
||||
input_side_packet: "FILE_PATH:model_path"
|
||||
output_side_packet: "CONTENTS:model_blob"
|
||||
}
|
||||
|
||||
node {
|
||||
calculator: "TfLiteModelCalculator"
|
||||
input_side_packet: "MODEL_BLOB:model_blob"
|
||||
output_side_packet: "MODEL:model"
|
||||
}
|
||||
|
||||
node {
|
||||
calculator: "TfLiteInferenceCalculator"
|
||||
input_stream: "TENSORS:tensor_in"
|
||||
output_stream: "TENSORS:tensor_out"
|
||||
input_side_packet: "MODEL:model"
|
||||
options {
|
||||
[mediapipe.TfLiteInferenceCalculatorOptions.ext] {
|
||||
use_gpu: false
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
DoSmokeTest(graph_proto);
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
Loading…
Reference in New Issue
Block a user