Enable run inference with a TFLite model containing multiple subgraphs. It uses the subgraph 0 as the default primary subgraph for inference. It will also log a warning in the case that there are more than one subgraph in the model.

PiperOrigin-RevId: 555579131
This commit is contained in:
MediaPipe Team 2023-08-10 11:28:16 -07:00 committed by Copybara-Service
parent a9c7e22ca4
commit 91f15d8e4a
2 changed files with 6 additions and 4 deletions

View File

@ -28,6 +28,7 @@ cc_library_with_tflite(
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/metadata:metadata_extractor",

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "absl/types/optional.h"
#include "flatbuffers/flatbuffers.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
@ -241,11 +242,11 @@ absl::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
absl::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
const core::ModelResources& model_resources) {
const tflite::Model& model = *model_resources.GetTfLiteModel();
// TODO: Investigate if there is any better solutions support
// running inference with multiple subgraphs.
if (model.subgraphs()->size() != 1) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"Image tflite models are assumed to have a single subgraph.",
MediaPipeTasksStatus::kInvalidArgumentError);
LOG(WARNING) << "TFLite model has more than 1 subgraphs. Use subrgaph 0 as "
"the primary subgraph for inference";
}
const auto* primary_subgraph = (*model.subgraphs())[0];
if (primary_subgraph->inputs()->size() != 1) {