Open-sources UniversalSentenceEncoderPreprocessorCalculator.

PiperOrigin-RevId: 482222697
This commit is contained in:
MediaPipe Team 2022-10-19 09:18:35 -07:00 committed by Copybara-Service
parent 70df9e2419
commit a18f91e04f
2 changed files with 18 additions and 3 deletions

View File

@ -289,6 +289,23 @@ cc_test(
], ],
) )
cc_library(
name = "universal_sentence_encoder_preprocessor_calculator",
srcs = ["universal_sentence_encoder_preprocessor_calculator.cc"],
deps = [
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
)
mediapipe_proto_library( mediapipe_proto_library(
name = "inference_calculator_proto", name = "inference_calculator_proto",
srcs = ["inference_calculator.proto"], srcs = ["inference_calculator.proto"],

View File

@ -88,9 +88,7 @@ RunUniversalSentenceEncoderPreprocessorCalculator(absl::string_view text) {
kNumInputTensorsForUniversalSentenceEncoder)); kNumInputTensorsForUniversalSentenceEncoder));
} }
if (tensor_vec[0].element_type() != Tensor::ElementType::kChar) { if (tensor_vec[0].element_type() != Tensor::ElementType::kChar) {
return absl::InvalidArgumentError(absl::Substitute( return absl::InvalidArgumentError("Expected tensor element type kChar");
"tensor has element type $0, expected $1", tensor_vec[0].element_type(),
Tensor::ElementType::kChar));
} }
std::vector<std::string> results; std::vector<std::string> results;
for (int i = 0; i < kNumInputTensorsForUniversalSentenceEncoder; ++i) { for (int i = 0; i < kNumInputTensorsForUniversalSentenceEncoder; ++i) {