Move BuildInputImageTensorSpecs to utils
PiperOrigin-RevId: 508829724
This commit is contained in:
parent
2c82f67097
commit
626f92caea
|
@ -71,32 +71,6 @@ struct ImagePreprocessingOutputStreams {
|
|||
Source<Image> image;
|
||||
};
|
||||
|
||||
// Builds an ImageTensorSpecs for configuring the preprocessing calculators.
|
||||
absl::StatusOr<ImageTensorSpecs> BuildImageTensorSpecs(
|
||||
const ModelResources& model_resources) {
|
||||
const tflite::Model& model = *model_resources.GetTfLiteModel();
|
||||
if (model.subgraphs()->size() != 1) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
"Image tflite models are assumed to have a single subgraph.",
|
||||
MediaPipeTasksStatus::kInvalidArgumentError);
|
||||
}
|
||||
const auto* primary_subgraph = (*model.subgraphs())[0];
|
||||
if (primary_subgraph->inputs()->size() != 1) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
"Image tflite models are assumed to have a single input.",
|
||||
MediaPipeTasksStatus::kInvalidArgumentError);
|
||||
}
|
||||
const auto* input_tensor =
|
||||
(*primary_subgraph->tensors())[(*primary_subgraph->inputs())[0]];
|
||||
ASSIGN_OR_RETURN(const auto* image_tensor_metadata,
|
||||
vision::GetImageTensorMetadataIfAny(
|
||||
*model_resources.GetMetadataExtractor(), 0));
|
||||
return vision::BuildInputImageTensorSpecs(*input_tensor,
|
||||
image_tensor_metadata);
|
||||
}
|
||||
|
||||
// Fills in the ImageToTensorCalculatorOptions based on the ImageTensorSpecs.
|
||||
absl::Status ConfigureImageToTensorCalculator(
|
||||
const ImageTensorSpecs& image_tensor_specs,
|
||||
|
@ -150,7 +124,7 @@ absl::Status ConfigureImagePreprocessingGraph(
|
|||
const ModelResources& model_resources, bool use_gpu,
|
||||
proto::ImagePreprocessingGraphOptions* options) {
|
||||
ASSIGN_OR_RETURN(auto image_tensor_specs,
|
||||
BuildImageTensorSpecs(model_resources));
|
||||
vision::BuildInputImageTensorSpecs(model_resources));
|
||||
MP_RETURN_IF_ERROR(ConfigureImageToTensorCalculator(
|
||||
image_tensor_specs, options->mutable_image_to_tensor_options()));
|
||||
// The GPU backend isn't able to process int data. If the input tensor is
|
||||
|
|
|
@ -109,32 +109,6 @@ absl::Status SanityCheckOptions(
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Builds an ImageTensorSpecs for configuring the image preprocessing subgraph.
|
||||
absl::StatusOr<ImageTensorSpecs> BuildImageTensorSpecs(
|
||||
const ModelResources& model_resources) {
|
||||
const tflite::Model& model = *model_resources.GetTfLiteModel();
|
||||
if (model.subgraphs()->size() != 1) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
"Hand landmark model is assumed to have a single subgraph.",
|
||||
MediaPipeTasksStatus::kInvalidArgumentError);
|
||||
}
|
||||
const auto* primary_subgraph = (*model.subgraphs())[0];
|
||||
if (primary_subgraph->inputs()->size() != 1) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
"Hand landmark model is assumed to have a single input.",
|
||||
MediaPipeTasksStatus::kInvalidArgumentError);
|
||||
}
|
||||
const auto* input_tensor =
|
||||
(*primary_subgraph->tensors())[(*primary_subgraph->inputs())[0]];
|
||||
ASSIGN_OR_RETURN(const auto* image_tensor_metadata,
|
||||
vision::GetImageTensorMetadataIfAny(
|
||||
*model_resources.GetMetadataExtractor(), 0));
|
||||
return vision::BuildInputImageTensorSpecs(*input_tensor,
|
||||
image_tensor_metadata);
|
||||
}
|
||||
|
||||
// Split hand landmark detection model output tensor into four parts,
|
||||
// representing landmarks, presence scores, handedness, and world landmarks,
|
||||
// respectively.
|
||||
|
@ -297,7 +271,7 @@ class SingleHandLandmarksDetectorGraph : public core::ModelTaskGraph {
|
|||
auto image_size = preprocessing[Output<std::pair<int, int>>("IMAGE_SIZE")];
|
||||
|
||||
ASSIGN_OR_RETURN(auto image_tensor_specs,
|
||||
BuildImageTensorSpecs(model_resources));
|
||||
BuildInputImageTensorSpecs(model_resources));
|
||||
|
||||
auto& inference = AddInference(
|
||||
model_resources, subgraph_options.base_options().acceleration(), graph);
|
||||
|
|
|
@ -12,16 +12,19 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_test_with_tflite")
|
||||
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite", "cc_test_with_tflite")
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
cc_library(
|
||||
cc_library_with_tflite(
|
||||
name = "image_tensor_specs",
|
||||
srcs = ["image_tensor_specs.cc"],
|
||||
hdrs = ["image_tensor_specs.h"],
|
||||
tflite_deps = [
|
||||
"//mediapipe/tasks/cc/core:model_resources",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:status",
|
||||
|
@ -44,11 +47,11 @@ cc_test_with_tflite(
|
|||
srcs = ["image_tensor_specs_test.cc"],
|
||||
data = ["//mediapipe/tasks/testdata/vision:test_models"],
|
||||
tflite_deps = [
|
||||
":image_tensor_specs",
|
||||
"//mediapipe/tasks/cc/core:model_resources",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
||||
],
|
||||
deps = [
|
||||
":image_tensor_specs",
|
||||
"//mediapipe/framework/deps:file_path",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/tasks/cc:common",
|
||||
|
|
|
@ -236,6 +236,32 @@ absl::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
|
|||
return result;
|
||||
}
|
||||
|
||||
// Builds an ImageTensorSpecs for configuring the preprocessing calculators.
|
||||
absl::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
|
||||
const core::ModelResources& model_resources) {
|
||||
const tflite::Model& model = *model_resources.GetTfLiteModel();
|
||||
if (model.subgraphs()->size() != 1) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
"Image tflite models are assumed to have a single subgraph.",
|
||||
MediaPipeTasksStatus::kInvalidArgumentError);
|
||||
}
|
||||
const auto* primary_subgraph = (*model.subgraphs())[0];
|
||||
if (primary_subgraph->inputs()->size() != 1) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
"Image tflite models are assumed to have a single input.",
|
||||
MediaPipeTasksStatus::kInvalidArgumentError);
|
||||
}
|
||||
const auto* input_tensor =
|
||||
(*primary_subgraph->tensors())[(*primary_subgraph->inputs())[0]];
|
||||
ASSIGN_OR_RETURN(const auto* image_tensor_metadata,
|
||||
vision::GetImageTensorMetadataIfAny(
|
||||
*model_resources.GetMetadataExtractor(), 0));
|
||||
return vision::BuildInputImageTensorSpecs(*input_tensor,
|
||||
image_tensor_metadata);
|
||||
}
|
||||
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
|
||||
|
||||
|
@ -90,6 +91,11 @@ absl::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
|
|||
const tflite::Tensor& image_tensor,
|
||||
const tflite::TensorMetadata* image_tensor_metadata);
|
||||
|
||||
// Build ImageTensorSpec from model resources. The tflite model must contain
|
||||
// single subgraph with single input tensor.
|
||||
absl::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
|
||||
const core::ModelResources& model_resources);
|
||||
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -171,6 +171,28 @@ TEST_F(ImageTensorSpecsTest,
|
|||
EXPECT_EQ(input_specs.normalization_options, absl::nullopt);
|
||||
}
|
||||
|
||||
TEST_F(ImageTensorSpecsTest, BuildInputImageTensorSpecsFromModelResources) {
|
||||
auto model_file = std::make_unique<core::proto::ExternalFile>();
|
||||
model_file->set_file_name(
|
||||
JoinPath("./", kTestDataDirectory, kMobileNetQuantizedPartialMetadata));
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto model_resources,
|
||||
core::ModelResources::Create(kTestModelResourcesTag,
|
||||
std::move(model_file)));
|
||||
const tflite::Model* model = model_resources->GetTfLiteModel();
|
||||
CHECK(model != nullptr);
|
||||
absl::StatusOr<ImageTensorSpecs> input_specs_or =
|
||||
BuildInputImageTensorSpecs(*model_resources);
|
||||
MP_ASSERT_OK(input_specs_or);
|
||||
|
||||
const ImageTensorSpecs& input_specs = input_specs_or.value();
|
||||
EXPECT_EQ(input_specs.image_width, 224);
|
||||
EXPECT_EQ(input_specs.image_height, 224);
|
||||
EXPECT_EQ(input_specs.color_space, ColorSpaceType_RGB);
|
||||
EXPECT_STREQ(EnumNameTensorType(input_specs.tensor_type),
|
||||
EnumNameTensorType(tflite::TensorType_UINT8));
|
||||
EXPECT_EQ(input_specs.normalization_options, absl::nullopt);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
|
|
Loading…
Reference in New Issue
Block a user