From 91f15d8e4a83d28559380aa6e9f3d8156079f4f9 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 10 Aug 2023 11:28:16 -0700 Subject: [PATCH] Enable run inference with a TFLite model containing multiple subgraphs. It uses the subgraph 0 as the default primary subgraph for inference. It will also log a warning in the case that there are more than one subgraph in the model. PiperOrigin-RevId: 555579131 --- mediapipe/tasks/cc/vision/utils/BUILD | 1 + mediapipe/tasks/cc/vision/utils/image_tensor_specs.cc | 9 +++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/mediapipe/tasks/cc/vision/utils/BUILD b/mediapipe/tasks/cc/vision/utils/BUILD index ae303441c..442fd2717 100644 --- a/mediapipe/tasks/cc/vision/utils/BUILD +++ b/mediapipe/tasks/cc/vision/utils/BUILD @@ -28,6 +28,7 @@ cc_library_with_tflite( visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/metadata:metadata_extractor", diff --git a/mediapipe/tasks/cc/vision/utils/image_tensor_specs.cc b/mediapipe/tasks/cc/vision/utils/image_tensor_specs.cc index 1041dd1f9..7d48c6282 100644 --- a/mediapipe/tasks/cc/vision/utils/image_tensor_specs.cc +++ b/mediapipe/tasks/cc/vision/utils/image_tensor_specs.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/types/optional.h" #include "flatbuffers/flatbuffers.h" #include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/status_macros.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" @@ -241,11 +242,11 @@ absl::StatusOr BuildInputImageTensorSpecs( absl::StatusOr BuildInputImageTensorSpecs( const core::ModelResources& model_resources) { const tflite::Model& model = *model_resources.GetTfLiteModel(); + // TODO: Investigate if there is any better solutions support + // running inference with multiple subgraphs. if (model.subgraphs()->size() != 1) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - "Image tflite models are assumed to have a single subgraph.", - MediaPipeTasksStatus::kInvalidArgumentError); + LOG(WARNING) << "TFLite model has more than 1 subgraphs. Use subrgaph 0 as " + "the primary subgraph for inference"; } const auto* primary_subgraph = (*model.subgraphs())[0]; if (primary_subgraph->inputs()->size() != 1) {