diff --git a/mediapipe/tasks/cc/core/base_options.cc b/mediapipe/tasks/cc/core/base_options.cc index ec85ea753..0e3bb7401 100644 --- a/mediapipe/tasks/cc/core/base_options.cc +++ b/mediapipe/tasks/cc/core/base_options.cc @@ -52,7 +52,7 @@ proto::BaseOptions ConvertBaseOptionsToProto(BaseOptions* base_options) { } switch (base_options->delegate) { case BaseOptions::Delegate::CPU: - base_options_proto.mutable_acceleration()->mutable_xnnpack(); + base_options_proto.mutable_acceleration()->mutable_tflite(); break; case BaseOptions::Delegate::GPU: base_options_proto.mutable_acceleration()->mutable_gpu(); diff --git a/mediapipe/tasks/cc/core/model_task_graph.cc b/mediapipe/tasks/cc/core/model_task_graph.cc index c6bc8f69b..90a38747c 100644 --- a/mediapipe/tasks/cc/core/model_task_graph.cc +++ b/mediapipe/tasks/cc/core/model_task_graph.cc @@ -122,6 +122,9 @@ class InferenceSubgraph : public Subgraph { case Acceleration::kGpu: delegate.mutable_gpu()->CopyFrom(acceleration.gpu()); break; + case Acceleration::kTflite: + delegate.mutable_tflite()->CopyFrom(acceleration.tflite()); + break; case Acceleration::DELEGATE_NOT_SET: // Deafult inference calculator setting. break; @@ -177,9 +180,9 @@ GenericNode& ModelTaskGraph::AddInference( ->CopyFrom(acceleration); // When the model resources tag is available, the ModelResourcesCalculator // will retrieve the cached model resources from the graph service by tag. - // Otherwise, provides the exteranal file and asks the + // Otherwise, provides the external file and asks the // ModelResourcesCalculator to create a local model resources in its - // Calcualtor::Open(). + // Calculator::Open(). if (!model_resources.GetTag().empty()) { inference_subgraph_opts.set_model_resources_tag(model_resources.GetTag()); } else { diff --git a/mediapipe/tasks/cc/core/proto/acceleration.proto b/mediapipe/tasks/cc/core/proto/acceleration.proto index a0575a5d5..bdfaff4d2 100644 --- a/mediapipe/tasks/cc/core/proto/acceleration.proto +++ b/mediapipe/tasks/cc/core/proto/acceleration.proto @@ -32,5 +32,6 @@ message Acceleration { oneof delegate { mediapipe.InferenceCalculatorOptions.Delegate.Xnnpack xnnpack = 1; mediapipe.InferenceCalculatorOptions.Delegate.Gpu gpu = 2; + mediapipe.InferenceCalculatorOptions.Delegate.TfLite tflite = 4; } } diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc index 8a9251152..8db3fa767 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc @@ -197,7 +197,7 @@ TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) { // interpreter errors (e.g., "Encountered unresolved custom op"). EXPECT_EQ(object_detector.status().code(), absl::StatusCode::kInternal); EXPECT_THAT(object_detector.status().message(), - HasSubstr("interpreter_builder(&interpreter) == kTfLiteOk")); + HasSubstr("interpreter->AllocateTensors() == kTfLiteOk")); } TEST_F(CreateFromOptionsTest, FailsWithMissingModel) { diff --git a/mediapipe/tasks/testdata/text/BUILD b/mediapipe/tasks/testdata/text/BUILD index 1b569127b..a12e607bb 100644 --- a/mediapipe/tasks/testdata/text/BUILD +++ b/mediapipe/tasks/testdata/text/BUILD @@ -26,6 +26,7 @@ mediapipe_files(srcs = [ "30k-clean.model", "albert_with_metadata.tflite", "bert_text_classifier.tflite", + "mobilebert_embedding_with_metadata.tflite", "mobilebert_with_metadata.tflite", "test_model_text_classifier_bool_output.tflite", "test_model_text_classifier_with_regex_tokenizer.tflite", @@ -83,3 +84,8 @@ filegroup( name = "bert_text_classifier_models", srcs = ["bert_text_classifier.tflite"], ) + +filegroup( + name = "mobilebert_embedding_model", + srcs = ["mobilebert_embedding_with_metadata.tflite"], +) diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index 8f4b70d38..b1d2c875a 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -340,6 +340,12 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/left_hands.jpg?generation=1661875796949017"], ) + http_file( + name = "com_google_mediapipe_mobilebert_embedding_with_metadata_tflite", + sha256 = "fa47142dcc6f446168bc672f2df9605b6da5d0c0d6264e9be62870282365b95c", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilebert_embedding_with_metadata.tflite?generation=1664516086197724"], + ) + http_file( name = "com_google_mediapipe_mobilebert_vocab_txt", sha256 = "07eced375cec144d27c900241f3e339478dec958f92fddbc551f295c992038a3",