Move BuildInputImageTensorSpecs to utils

PiperOrigin-RevId: 508829724
MediaPipe Team authored 2023-02-10 21:54:49 -08:00; committed by Copybara-Service
parent 2c82f67097
commit 626f92caea
6 changed files with 62 additions and 57 deletions
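
The image preprocessing graph and the single-hand landmarks detector graph each carried a near-identical private BuildImageTensorSpecs helper with the same model sanity checks. This change deletes both copies and replaces them with a single BuildInputImageTensorSpecs overload on core::ModelResources in the vision utils image_tensor_specs library, with the BUILD adjustments and a unit test to match.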

View File

@@ -71,32 +71,6 @@ struct ImagePreprocessingOutputStreams {
   Source<Image> image;
 };
 
-// Builds an ImageTensorSpecs for configuring the preprocessing calculators.
-absl::StatusOr<ImageTensorSpecs> BuildImageTensorSpecs(
-    const ModelResources& model_resources) {
-  const tflite::Model& model = *model_resources.GetTfLiteModel();
-  if (model.subgraphs()->size() != 1) {
-    return CreateStatusWithPayload(
-        absl::StatusCode::kInvalidArgument,
-        "Image tflite models are assumed to have a single subgraph.",
-        MediaPipeTasksStatus::kInvalidArgumentError);
-  }
-  const auto* primary_subgraph = (*model.subgraphs())[0];
-  if (primary_subgraph->inputs()->size() != 1) {
-    return CreateStatusWithPayload(
-        absl::StatusCode::kInvalidArgument,
-        "Image tflite models are assumed to have a single input.",
-        MediaPipeTasksStatus::kInvalidArgumentError);
-  }
-  const auto* input_tensor =
-      (*primary_subgraph->tensors())[(*primary_subgraph->inputs())[0]];
-  ASSIGN_OR_RETURN(const auto* image_tensor_metadata,
-                   vision::GetImageTensorMetadataIfAny(
-                       *model_resources.GetMetadataExtractor(), 0));
-  return vision::BuildInputImageTensorSpecs(*input_tensor,
-                                            image_tensor_metadata);
-}
-
 // Fills in the ImageToTensorCalculatorOptions based on the ImageTensorSpecs.
 absl::Status ConfigureImageToTensorCalculator(
     const ImageTensorSpecs& image_tensor_specs,
@@ -150,7 +124,7 @@ absl::Status ConfigureImagePreprocessingGraph(
     const ModelResources& model_resources, bool use_gpu,
     proto::ImagePreprocessingGraphOptions* options) {
   ASSIGN_OR_RETURN(auto image_tensor_specs,
-                   BuildImageTensorSpecs(model_resources));
+                   vision::BuildInputImageTensorSpecs(model_resources));
   MP_RETURN_IF_ERROR(ConfigureImageToTensorCalculator(
       image_tensor_specs, options->mutable_image_to_tensor_options()));
   // The GPU backend isn't able to process int data. If the input tensor is

View File

@@ -109,32 +109,6 @@ absl::Status SanityCheckOptions(
   return absl::OkStatus();
 }
 
-// Builds an ImageTensorSpecs for configuring the image preprocessing subgraph.
-absl::StatusOr<ImageTensorSpecs> BuildImageTensorSpecs(
-    const ModelResources& model_resources) {
-  const tflite::Model& model = *model_resources.GetTfLiteModel();
-  if (model.subgraphs()->size() != 1) {
-    return CreateStatusWithPayload(
-        absl::StatusCode::kInvalidArgument,
-        "Hand landmark model is assumed to have a single subgraph.",
-        MediaPipeTasksStatus::kInvalidArgumentError);
-  }
-  const auto* primary_subgraph = (*model.subgraphs())[0];
-  if (primary_subgraph->inputs()->size() != 1) {
-    return CreateStatusWithPayload(
-        absl::StatusCode::kInvalidArgument,
-        "Hand landmark model is assumed to have a single input.",
-        MediaPipeTasksStatus::kInvalidArgumentError);
-  }
-  const auto* input_tensor =
-      (*primary_subgraph->tensors())[(*primary_subgraph->inputs())[0]];
-  ASSIGN_OR_RETURN(const auto* image_tensor_metadata,
-                   vision::GetImageTensorMetadataIfAny(
-                       *model_resources.GetMetadataExtractor(), 0));
-  return vision::BuildInputImageTensorSpecs(*input_tensor,
-                                            image_tensor_metadata);
-}
-
 // Split hand landmark detection model output tensor into four parts,
 // representing landmarks, presence scores, handedness, and world landmarks,
 // respectively.
@@ -297,7 +271,7 @@ class SingleHandLandmarksDetectorGraph : public core::ModelTaskGraph {
     auto image_size = preprocessing[Output<std::pair<int, int>>("IMAGE_SIZE")];
 
     ASSIGN_OR_RETURN(auto image_tensor_specs,
-                     BuildImageTensorSpecs(model_resources));
+                     BuildInputImageTensorSpecs(model_resources));
     auto& inference = AddInference(
         model_resources, subgraph_options.base_options().acceleration(), graph);

View File

@@ -12,16 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_test_with_tflite")
+load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite", "cc_test_with_tflite")
 
 package(default_visibility = ["//mediapipe/tasks:internal"])
 
 licenses(["notice"])
 
-cc_library(
+cc_library_with_tflite(
     name = "image_tensor_specs",
     srcs = ["image_tensor_specs.cc"],
    hdrs = ["image_tensor_specs.h"],
+    tflite_deps = [
+        "//mediapipe/tasks/cc/core:model_resources",
+    ],
     deps = [
         "//mediapipe/framework/port:integral_types",
         "//mediapipe/framework/port:status",
@@ -44,11 +47,11 @@ cc_test_with_tflite(
     srcs = ["image_tensor_specs_test.cc"],
     data = ["//mediapipe/tasks/testdata/vision:test_models"],
     tflite_deps = [
+        ":image_tensor_specs",
         "//mediapipe/tasks/cc/core:model_resources",
         "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
     ],
     deps = [
-        ":image_tensor_specs",
         "//mediapipe/framework/deps:file_path",
         "//mediapipe/framework/port:gtest_main",
         "//mediapipe/tasks/cc:common",

View File

@@ -236,6 +236,32 @@ absl::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
   return result;
 }
 
+// Builds an ImageTensorSpecs for configuring the preprocessing calculators.
+absl::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
+    const core::ModelResources& model_resources) {
+  const tflite::Model& model = *model_resources.GetTfLiteModel();
+  if (model.subgraphs()->size() != 1) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        "Image tflite models are assumed to have a single subgraph.",
+        MediaPipeTasksStatus::kInvalidArgumentError);
+  }
+  const auto* primary_subgraph = (*model.subgraphs())[0];
+  if (primary_subgraph->inputs()->size() != 1) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        "Image tflite models are assumed to have a single input.",
+        MediaPipeTasksStatus::kInvalidArgumentError);
+  }
+  const auto* input_tensor =
+      (*primary_subgraph->tensors())[(*primary_subgraph->inputs())[0]];
+  ASSIGN_OR_RETURN(const auto* image_tensor_metadata,
+                   vision::GetImageTensorMetadataIfAny(
+                       *model_resources.GetMetadataExtractor(), 0));
+  return vision::BuildInputImageTensorSpecs(*input_tensor,
+                                            image_tensor_metadata);
+}
+
 }  // namespace vision
 }  // namespace tasks
 }  // namespace mediapipe
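
The moved overload is a thin wrapper over the existing tensor-level BuildInputImageTensorSpecs: it enforces the single-subgraph, single-input assumption, fetches the input tensor's metadata via GetImageTensorMetadataIfAny, and delegates, so both overloads derive the specs identically.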

View File

@@ -20,6 +20,7 @@ limitations under the License.
 #include "absl/status/statusor.h"
 #include "absl/types/optional.h"
+#include "mediapipe/tasks/cc/core/model_resources.h"
 #include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
 #include "mediapipe/tasks/metadata/metadata_schema_generated.h"
@@ -90,6 +91,11 @@ absl::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
     const tflite::Tensor& image_tensor,
     const tflite::TensorMetadata* image_tensor_metadata);
 
+// Builds an ImageTensorSpecs from model resources. The tflite model must
+// contain a single subgraph with a single input tensor.
+absl::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
+    const core::ModelResources& model_resources);
+
 }  // namespace vision
 }  // namespace tasks
 }  // namespace mediapipe
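
With this declaration in place, deriving input specs from a model file becomes a single call on top of ModelResources. The sketch below is illustrative rather than part of the commit: the helper name GetInputSpecsForModel and the "specs_example" tag are made up, and the include paths assume the library lives under mediapipe/tasks/cc/vision/utils.

#include <memory>
#include <string>
#include <utility>

#include "absl/status/statusor.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"

// Hypothetical helper: load a .tflite model and derive its input image specs.
absl::StatusOr<mediapipe::tasks::vision::ImageTensorSpecs>
GetInputSpecsForModel(const std::string& model_path) {
  auto model_file =
      std::make_unique<mediapipe::tasks::core::proto::ExternalFile>();
  model_file->set_file_name(model_path);
  // "specs_example" is an arbitrary tag identifying this ModelResources.
  ASSIGN_OR_RETURN(auto model_resources,
                   mediapipe::tasks::core::ModelResources::Create(
                       "specs_example", std::move(model_file)));
  // One call now runs the subgraph/input sanity checks, reads the image tensor
  // metadata, and fills in width, height, color space, and tensor type.
  return mediapipe::tasks::vision::BuildInputImageTensorSpecs(*model_resources);
}

After this change, ConfigureImagePreprocessingGraph and SingleHandLandmarksDetectorGraph reduce to exactly this call on the ModelResources they already hold.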

View File

@@ -171,6 +171,28 @@ TEST_F(ImageTensorSpecsTest,
   EXPECT_EQ(input_specs.normalization_options, absl::nullopt);
 }
 
+TEST_F(ImageTensorSpecsTest, BuildInputImageTensorSpecsFromModelResources) {
+  auto model_file = std::make_unique<core::proto::ExternalFile>();
+  model_file->set_file_name(
+      JoinPath("./", kTestDataDirectory, kMobileNetQuantizedPartialMetadata));
+  MP_ASSERT_OK_AND_ASSIGN(auto model_resources,
+                          core::ModelResources::Create(kTestModelResourcesTag,
+                                                       std::move(model_file)));
+  const tflite::Model* model = model_resources->GetTfLiteModel();
+  CHECK(model != nullptr);
+
+  absl::StatusOr<ImageTensorSpecs> input_specs_or =
+      BuildInputImageTensorSpecs(*model_resources);
+  MP_ASSERT_OK(input_specs_or);
+  const ImageTensorSpecs& input_specs = input_specs_or.value();
+  EXPECT_EQ(input_specs.image_width, 224);
+  EXPECT_EQ(input_specs.image_height, 224);
+  EXPECT_EQ(input_specs.color_space, ColorSpaceType_RGB);
+  EXPECT_STREQ(EnumNameTensorType(input_specs.tensor_type),
+               EnumNameTensorType(tflite::TensorType_UINT8));
+  EXPECT_EQ(input_specs.normalization_options, absl::nullopt);
+}
+
 }  // namespace
 }  // namespace vision
 }  // namespace tasks