diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 6b63403f7..c52e2e283 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -199,6 +199,25 @@ cc_library( alwayslink = 1, ) +cc_test( + name = "bert_preprocessor_calculator_test", + srcs = ["bert_preprocessor_calculator_test.cc"], + data = ["//mediapipe/tasks/testdata/text:bert_text_classifier_models"], + linkopts = ["-ldl"], + deps = [ + ":bert_preprocessor_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + mediapipe_proto_library( name = "regex_preprocessor_calculator_proto", srcs = ["regex_preprocessor_calculator.proto"], diff --git a/mediapipe/calculators/tensor/bert_preprocessor_calculator_test.cc b/mediapipe/calculators/tensor/bert_preprocessor_calculator_test.cc new file mode 100644 index 000000000..b497a6168 --- /dev/null +++ b/mediapipe/calculators/tensor/bert_preprocessor_calculator_test.cc @@ -0,0 +1,154 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include <memory>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/strings/substitute.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/formats/tensor.h"
+#include "mediapipe/framework/port/gmock.h"
+#include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/framework/port/parse_text_proto.h"
+#include "mediapipe/framework/port/status_matchers.h"
+#include "mediapipe/tasks/cc/core/utils.h"
+#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
+
+namespace mediapipe {
+namespace {
+
+using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
+using ::testing::ElementsAreArray;
+
+// Number of tensors the calculator is expected to emit per input text.
+constexpr int kNumInputTensorsForBert = 3;
+// Value passed as `bert_max_seq_len`; also the number of ints read back from
+// each output tensor.
+constexpr int kBertMaxSeqLen = 128;
+constexpr absl::string_view kTestModelPath =
+    "mediapipe/tasks/testdata/text/bert_text_classifier.tflite";
+
+// Runs `text` through a one-node graph containing BertPreprocessorCalculator
+// and returns the int32 contents of its output tensors.
+//
+// Args:
+//   text: the input string sent on the "text" stream at Timestamp(0).
+//   model_path: path of the model file whose metadata is handed to the
+//     calculator through the "metadata_extractor" side packet.
+//
+// Returns:
+//   kNumInputTensorsForBert vectors of kBertMaxSeqLen ints each, in the order
+//   the calculator emitted the tensors, or an error status when the graph
+//   fails or the output does not have the expected count, element type or
+//   element count.
+absl::StatusOr<std::vector<std::vector<int>>> RunBertPreprocessorCalculator(
+    absl::string_view text, absl::string_view model_path) {
+  auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
+      absl::Substitute(R"(
+        input_stream: "text"
+        output_stream: "tensors"
+        node {
+          calculator: "BertPreprocessorCalculator"
+          input_stream: "TEXT:text"
+          input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
+          output_stream: "TENSORS:tensors"
+          options {
+            [mediapipe.BertPreprocessorCalculatorOptions.ext] {
+              bert_max_seq_len: $0
+            }
+          }
+        }
+      )",
+                       kBertMaxSeqLen));
+  std::vector<Packet> output_packets;
+  tool::AddVectorSink("tensors", &graph_config, &output_packets);
+
+  // absl::string_view::data() is not guaranteed to be NUL-terminated, so copy
+  // the path into an owning string before passing it as a C string.
+  const std::string model_file(model_path);
+  std::string model_buffer =
+      tasks::core::LoadBinaryContent(model_file.c_str());
+  ASSIGN_OR_RETURN(std::unique_ptr<ModelMetadataExtractor> metadata_extractor,
+                   ModelMetadataExtractor::CreateFromModelBuffer(
+                       model_buffer.data(), model_buffer.size()));
+
+  // Run the graph.
+  CalculatorGraph graph;
+  MP_RETURN_IF_ERROR(graph.Initialize(
+      graph_config,
+      {{"metadata_extractor",
+        MakePacket<ModelMetadataExtractor>(std::move(*metadata_extractor))}}));
+  MP_RETURN_IF_ERROR(graph.StartRun({}));
+  MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
+      "text", MakePacket<std::string>(text).At(Timestamp(0))));
+  MP_RETURN_IF_ERROR(graph.WaitUntilIdle());
+
+  if (output_packets.size() != 1) {
+    return absl::InvalidArgumentError(absl::Substitute(
+        "output_packets has size $0, expected 1", output_packets.size()));
+  }
+  const std::vector<Tensor>& tensor_vec =
+      output_packets[0].Get<std::vector<Tensor>>();
+  if (tensor_vec.size() != static_cast<size_t>(kNumInputTensorsForBert)) {
+    return absl::InvalidArgumentError(
+        absl::Substitute("tensor_vec has size $0, expected $1",
+                         tensor_vec.size(), kNumInputTensorsForBert));
+  }
+
+  std::vector<std::vector<int>> results;
+  results.reserve(tensor_vec.size());
+  for (const Tensor& tensor : tensor_vec) {
+    if (tensor.element_type() != Tensor::ElementType::kInt32) {
+      return absl::InvalidArgumentError("Expected tensor element type kInt32");
+    }
+    // Guard the fixed-length copy below against a shorter tensor.
+    const int num_elements = tensor.shape().num_elements();
+    if (num_elements != kBertMaxSeqLen) {
+      return absl::InvalidArgumentError(
+          absl::Substitute("tensor has $0 elements, expected $1", num_elements,
+                           kBertMaxSeqLen));
+    }
+    // Bind the read view to a named local so it stays alive while the data is
+    // copied; taking the pointer from a temporary view would leave it in use
+    // after the view has been destroyed.
+    auto read_view = tensor.GetCpuReadView();
+    const int* buffer = read_view.buffer<int>();
+    results.emplace_back(buffer, buffer + kBertMaxSeqLen);
+  }
+  MP_RETURN_IF_ERROR(graph.CloseAllPacketSources());
+  MP_RETURN_IF_ERROR(graph.WaitUntilDone());
+  return results;
+}
+
+// A short sentence: ids are followed by zero padding up to kBertMaxSeqLen.
+TEST(BertPreprocessorCalculatorTest, TextClassifierWithBertModel) {
+  std::vector<std::vector<int>> expected_result = {
+      {101, 2009, 1005, 1055, 1037, 11951, 1998, 2411, 12473, 4990, 102}};
+  // segment_ids
+  expected_result.push_back(std::vector<int>(kBertMaxSeqLen, 0));
+  // input_masks
+  expected_result.push_back(std::vector<int>(expected_result[0].size(), 1));
+  expected_result[2].resize(kBertMaxSeqLen);
+  // padding input_ids
+  expected_result[0].resize(kBertMaxSeqLen);
+
+  MP_ASSERT_OK_AND_ASSIGN(
+      std::vector<std::vector<int>> processed_tensor_values,
+      RunBertPreprocessorCalculator(
+          "it's a charming and often affecting journey", kTestModelPath));
+  EXPECT_THAT(processed_tensor_values, ElementsAreArray(expected_result));
+}
+
+// An input longer than kBertMaxSeqLen tokens: the expected ids fill the whole
+// sequence and end with the "[SEP]" id.
+TEST(BertPreprocessorCalculatorTest, LongInput) {
+  std::stringstream long_input;
+  long_input
+      << "it's a charming and often affecting journey and this is a long";
+  for (int i = 0; i < kBertMaxSeqLen; ++i) {
+    long_input << " long";
+  }
+  long_input << " movie review";
+  std::vector<std::vector<int>> expected_result = {
+      {101, 2009, 1005, 1055, 1037, 11951, 1998, 2411, 12473, 4990, 1998, 2023,
+       2003, 1037}};
+  // "long" id
+  expected_result[0].resize(kBertMaxSeqLen - 1, 2146);
+  // "[SEP]" id
+  expected_result[0].push_back(102);
+  // segment_ids
+  expected_result.push_back(std::vector<int>(kBertMaxSeqLen, 0));
+  // input_masks
+  expected_result.push_back(std::vector<int>(kBertMaxSeqLen, 1));
+
+  MP_ASSERT_OK_AND_ASSIGN(
+      std::vector<std::vector<int>> processed_tensor_values,
+      RunBertPreprocessorCalculator(long_input.str(), kTestModelPath));
+  EXPECT_THAT(processed_tensor_values, ElementsAreArray(expected_result));
+}
+
+}  // namespace
+}  // namespace mediapipe