Enable run inference with a TFLite model containing multiple subgraphs. It uses the subgraph 0 as the default primary subgraph for inference. It will also log a warning in the case that there are more than one subgraph in the model.

PiperOrigin-RevId: 555579131
This commit is contained in:
MediaPipe Team 2023-08-10 11:28:16 -07:00 committed by Copybara-Service
parent a9c7e22ca4
commit 91f15d8e4a
2 changed files with 6 additions and 4 deletions

View File

@ -28,6 +28,7 @@ cc_library_with_tflite(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
"//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/metadata:metadata_extractor", "//mediapipe/tasks/cc/metadata:metadata_extractor",

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "absl/types/optional.h" #include "absl/types/optional.h"
#include "flatbuffers/flatbuffers.h" #include "flatbuffers/flatbuffers.h"
#include "mediapipe/framework/port/integral_types.h" #include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/status_macros.h" #include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
@ -241,11 +242,11 @@ absl::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
absl::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs( absl::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
const core::ModelResources& model_resources) { const core::ModelResources& model_resources) {
const tflite::Model& model = *model_resources.GetTfLiteModel(); const tflite::Model& model = *model_resources.GetTfLiteModel();
// TODO: Investigate if there is any better solutions support
// running inference with multiple subgraphs.
if (model.subgraphs()->size() != 1) { if (model.subgraphs()->size() != 1) {
return CreateStatusWithPayload( LOG(WARNING) << "TFLite model has more than 1 subgraphs. Use subrgaph 0 as "
absl::StatusCode::kInvalidArgument, "the primary subgraph for inference";
"Image tflite models are assumed to have a single subgraph.",
MediaPipeTasksStatus::kInvalidArgumentError);
} }
const auto* primary_subgraph = (*model.subgraphs())[0]; const auto* primary_subgraph = (*model.subgraphs())[0];
if (primary_subgraph->inputs()->size() != 1) { if (primary_subgraph->inputs()->size() != 1) {