Support hair segmentation model in image segmenter
PiperOrigin-RevId: 515151150
This commit is contained in:
parent
2fb62e4c29
commit
9f1f4273d0
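For context, a minimal sketch (not part of the commit) of how the newly supported hair segmentation model is driven through the C++ ImageSegmenter API. It mirrors the SucceedsHairSegmentation test added below; the model path is a placeholder, and the two-mask output layout (index 0 = background, index 1 = hair) is an assumption taken from that test's assertions.

#include <memory>
#include <utility>
#include <vector>

#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h"

using ::mediapipe::Image;
using ::mediapipe::tasks::vision::image_segmenter::ImageSegmenter;
using ::mediapipe::tasks::vision::image_segmenter::ImageSegmenterOptions;

// Returns the hair confidence mask for `image`. The image should be RGBA:
// a 4-channel model input is what selects the legacy preprocessing path
// introduced by this change.
absl::StatusOr<Image> SegmentHair(Image image) {
  auto options = std::make_unique<ImageSegmenterOptions>();
  options->base_options.model_asset_path =
      "/path/to/hair_segmentation.tflite";  // Placeholder path.
  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
  options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
  ASSIGN_OR_RETURN(std::unique_ptr<ImageSegmenter> segmenter,
                   ImageSegmenter::Create(std::move(options)));
  ASSIGN_OR_RETURN(std::vector<Image> confidence_masks,
                   segmenter->Segment(image));
  MP_RETURN_IF_ERROR(segmenter->Close());
  return confidence_masks[1];  // [0] is background, [1] is hair.
}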
@@ -47,26 +47,35 @@ cc_library(
     srcs = ["image_segmenter_graph.cc"],
     deps = [
         "//mediapipe/calculators/core:merge_to_vector_calculator",
+        "//mediapipe/calculators/image:image_clone_calculator_cc_proto",
         "//mediapipe/calculators/image:image_properties_calculator",
+        "//mediapipe/calculators/image:image_transformation_calculator",
+        "//mediapipe/calculators/image:image_transformation_calculator_cc_proto",
         "//mediapipe/calculators/tensor:image_to_tensor_calculator",
+        "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
         "//mediapipe/calculators/tensor:inference_calculator",
+        "//mediapipe/calculators/tensor:tensor_converter_calculator",
+        "//mediapipe/calculators/tensor:tensor_converter_calculator_cc_proto",
+        "//mediapipe/calculators/util:from_image_calculator",
         "//mediapipe/framework/api2:builder",
         "//mediapipe/framework/api2:port",
         "//mediapipe/framework/formats:image",
         "//mediapipe/framework/formats:rect_cc_proto",
+        "//mediapipe/framework/formats:tensor",
         "//mediapipe/framework/port:status",
+        "//mediapipe/gpu:scale_mode_cc_proto",
         "//mediapipe/tasks/cc:common",
         "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
         "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto",
         "//mediapipe/tasks/cc/core:model_resources",
         "//mediapipe/tasks/cc/core:model_task_graph",
-        "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
         "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
         "//mediapipe/tasks/cc/metadata:metadata_extractor",
         "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator",
         "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_cc_proto",
         "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
         "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto",
+        "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
         "//mediapipe/tasks/metadata:metadata_schema_cc",
         "//mediapipe/util:label_map_cc_proto",
         "//mediapipe/util:label_map_util",
@@ -20,22 +20,26 @@ limitations under the License.
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/str_format.h"
+#include "mediapipe/calculators/image/image_clone_calculator.pb.h"
+#include "mediapipe/calculators/image/image_transformation_calculator.pb.h"
+#include "mediapipe/calculators/tensor/tensor_converter_calculator.pb.h"
 #include "mediapipe/framework/api2/builder.h"
 #include "mediapipe/framework/api2/port.h"
 #include "mediapipe/framework/formats/image.h"
 #include "mediapipe/framework/formats/rect.pb.h"
+#include "mediapipe/framework/formats/tensor.h"
 #include "mediapipe/framework/port/status_macros.h"
 #include "mediapipe/tasks/cc/common.h"
 #include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
 #include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h"
 #include "mediapipe/tasks/cc/core/model_resources.h"
 #include "mediapipe/tasks/cc/core/model_task_graph.h"
-#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
 #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
 #include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
 #include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h"
 #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
 #include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h"
+#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
 #include "mediapipe/tasks/metadata/metadata_schema_generated.h"
 #include "mediapipe/util/label_map.pb.h"
 #include "mediapipe/util/label_map_util.h"
@@ -59,13 +63,14 @@ using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
 using ::mediapipe::tasks::vision::image_segmenter::proto::
     ImageSegmenterGraphOptions;
 using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions;
-using ::tflite::Tensor;
 using ::tflite::TensorMetadata;
 using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
 
 constexpr char kSegmentationTag[] = "SEGMENTATION";
 constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
 constexpr char kImageTag[] = "IMAGE";
+constexpr char kImageCpuTag[] = "IMAGE_CPU";
+constexpr char kImageGpuTag[] = "IMAGE_GPU";
 constexpr char kNormRectTag[] = "NORM_RECT";
 constexpr char kTensorsTag[] = "TENSORS";
 constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
@@ -78,6 +83,13 @@ struct ImageSegmenterOutputs {
   Source<Image> image;
 };
 
+// Struct holding the image and input tensors after image preprocessing and
+// transferred to the requested device.
+struct ImageAndTensorsOnDevice {
+  Source<Image> image;
+  Source<std::vector<Tensor>> tensors;
+};
+
 }  // namespace
 
 absl::Status SanityCheckOptions(const ImageSegmenterGraphOptions& options) {
@@ -144,7 +156,8 @@ absl::Status ConfigureTensorsToSegmentationCalculator(
   return absl::OkStatus();
 }
 
-absl::StatusOr<const Tensor*> GetOutputTensor(
+// Get the output tensor from the tflite model of given model resources.
+absl::StatusOr<const tflite::Tensor*> GetOutputTensor(
     const core::ModelResources& model_resources) {
   const tflite::Model& model = *model_resources.GetTfLiteModel();
   const auto* primary_subgraph = (*model.subgraphs())[0];
@@ -153,6 +166,115 @@ absl::StatusOr<const Tensor*> GetOutputTensor(
   return output_tensor;
 }
 
+// Get the input tensor from the tflite model of given model resources.
+absl::StatusOr<const tflite::Tensor*> GetInputTensor(
+    const core::ModelResources& model_resources) {
+  const tflite::Model& model = *model_resources.GetTfLiteModel();
+  const auto* primary_subgraph = (*model.subgraphs())[0];
+  const auto* input_tensor =
+      (*primary_subgraph->tensors())[(*primary_subgraph->inputs())[0]];
+  return input_tensor;
+}
+
+// Configure the ImageTransformationCalculator according to the input tensor.
+void ConfigureImageTransformationCalculator(
+    const tflite::Tensor& tflite_input_tensor,
+    mediapipe::ImageTransformationCalculatorOptions& options) {
+  options.set_output_height(tflite_input_tensor.shape()->data()[1]);
+  options.set_output_width(tflite_input_tensor.shape()->data()[2]);
+}
+
+// Configure the TensorConverterCalculator to convert the image to tensor.
+void ConfigureTensorConverterCalculator(
+    const ImageTensorSpecs& image_tensor_specs,
+    mediapipe::TensorConverterCalculatorOptions& options) {
+  float mean = image_tensor_specs.normalization_options->mean_values[0];
+  float std = image_tensor_specs.normalization_options->std_values[0];
+  options.set_max_num_channels(4);
+  options.mutable_output_tensor_float_range()->set_min((0.0f - mean) / std);
+  options.mutable_output_tensor_float_range()->set_max((255.0f - mean) / std);
+}
+
+// Image preprocessing step to convert the given image to the input tensors for
+// the tflite model.
+absl::StatusOr<ImageAndTensorsOnDevice> ConvertImageToTensors(
+    Source<Image> image_in, Source<NormalizedRect> norm_rect_in, bool use_gpu,
+    const core::ModelResources& model_resources, Graph& graph) {
+  ASSIGN_OR_RETURN(const tflite::Tensor* tflite_input_tensor,
+                   GetInputTensor(model_resources));
+  if (tflite_input_tensor->shape()->size() != 4) {
+    return absl::InvalidArgumentError(
+        absl::StrFormat("Expect segmentation model has input image tensor to "
                        "be 4 dims. Got input tensor with "
+                        "dims: %d",
+                        tflite_input_tensor->shape()->size()));
+  }
+  const int input_tensor_channel = tflite_input_tensor->shape()->data()[3];
+  if (input_tensor_channel != 3 && input_tensor_channel != 4) {
+    return absl::InvalidArgumentError(absl::StrFormat(
+        "Expect segmentation model has input image tensor with channels = 3 or "
+        "4. Get "
+        "channel = %d",
+        tflite_input_tensor->shape()->data()[3]));
+  } else if (input_tensor_channel == 3) {
+    // ImagePreprocessingGraph is backed by ImageToTensorCalculator which only
+    // supports Tensor with channel = 3.
+    auto& preprocessing = graph.AddNode(
+        "mediapipe.tasks.components.processors.ImagePreprocessingGraph");
+    MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
+        model_resources, use_gpu,
+        &preprocessing.GetOptions<tasks::components::processors::proto::
+                                      ImagePreprocessingGraphOptions>()));
+    image_in >> preprocessing.In(kImageTag);
+    norm_rect_in >> preprocessing.In(kNormRectTag);
+    return {{preprocessing.Out(kImageTag).Cast<Image>(),
+             preprocessing.Out(kTensorsTag).Cast<std::vector<Tensor>>()}};
+  } else {
+    // TODO: Remove legacy preprocessing calculators.
+    // For segmentation model with input Tensor with channel = 4, use legacy
+    // TfLite preprocessing calculators.
+
+    // Upload image to GPU if requested to use gpu.
+    auto& image_clone = graph.AddNode("ImageCloneCalculator");
+    image_clone.GetOptions<mediapipe::ImageCloneCalculatorOptions>()
+        .set_output_on_gpu(use_gpu);
+    image_in >> image_clone.In("");
+    Source<Image> image_on_device = image_clone.Out("").Cast<Image>();
+
+    // Convert from Image to legacy ImageFrame or GpuBuffer.
+    auto& from_image = graph.AddNode("FromImageCalculator");
+    image_on_device >> from_image.In(kImageTag);
+    auto image_cpu_or_gpu =
+        from_image.Out(use_gpu ? kImageGpuTag : kImageCpuTag);
+
+    // Resize the input image to the model input size.
+    auto& image_transformation = graph.AddNode("ImageTransformationCalculator");
+    ConfigureImageTransformationCalculator(
+        *tflite_input_tensor,
+        image_transformation
+            .GetOptions<mediapipe::ImageTransformationCalculatorOptions>());
+    const absl::string_view image_or_image_gpu_tag =
+        use_gpu ? kImageGpuTag : kImageTag;
+    image_cpu_or_gpu >> image_transformation.In(image_or_image_gpu_tag);
+    auto transformed_image = image_transformation.Out(image_or_image_gpu_tag);
+
+    // Convert image to mediapipe tensor.
+    auto& tensor_converter = graph.AddNode("TensorConverterCalculator");
+    ASSIGN_OR_RETURN(auto image_tensor_specs,
+                     vision::BuildInputImageTensorSpecs(model_resources));
+    ConfigureTensorConverterCalculator(
+        image_tensor_specs,
+        tensor_converter
+            .GetOptions<mediapipe::TensorConverterCalculatorOptions>());
+
+    transformed_image >> tensor_converter.In(image_or_image_gpu_tag);
+    auto tensors =
+        tensor_converter.Out(kTensorsTag).Cast<std::vector<Tensor>>();
+
+    return {{image_on_device, tensors}};
+  }
+}
+
 // An "mediapipe.tasks.vision.ImageSegmenterGraph" performs semantic
 // segmentation.
 // Two kinds of outputs are provided: SEGMENTATION and GROUPED_SEGMENTATION.
@@ -244,23 +366,17 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
 
     // Adds preprocessing calculators and connects them to the graph input image
     // stream.
-    auto& preprocessing = graph.AddNode(
-        "mediapipe.tasks.components.processors.ImagePreprocessingGraph");
     bool use_gpu =
         components::processors::DetermineImagePreprocessingGpuBackend(
             task_options.base_options().acceleration());
-    MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
-        model_resources, use_gpu,
-        &preprocessing.GetOptions<tasks::components::processors::proto::
-                                      ImagePreprocessingGraphOptions>()));
-    image_in >> preprocessing.In(kImageTag);
-    norm_rect_in >> preprocessing.In(kNormRectTag);
-
+    ASSIGN_OR_RETURN(auto image_and_tensors,
+                     ConvertImageToTensors(image_in, norm_rect_in, use_gpu,
+                                           model_resources, graph));
     // Adds inference subgraph and connects its input stream to the output
     // tensors produced by the ImageToTensorCalculator.
     auto& inference = AddInference(
         model_resources, task_options.base_options().acceleration(), graph);
-    preprocessing.Out(kTensorsTag) >> inference.In(kTensorsTag);
+    image_and_tensors.tensors >> inference.In(kTensorsTag);
 
     // Adds segmentation calculators for output streams.
     auto& tensor_to_images =
@@ -283,7 +399,7 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
       segmented_masks.push_back(
          Source<Image>(tensor_to_images[Output<Image>(kSegmentationTag)]));
     } else {
-      ASSIGN_OR_RETURN(const Tensor* output_tensor,
+      ASSIGN_OR_RETURN(const tflite::Tensor* output_tensor,
                        GetOutputTensor(model_resources));
       const int segmentation_streams_num = *output_tensor->shape()->rbegin();
       for (int i = 0; i < segmentation_streams_num; ++i) {
@@ -291,9 +407,8 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
             tensor_to_images[Output<Image>::Multiple(kSegmentationTag)][i]));
       }
     }
-    return ImageSegmenterOutputs{
-        /*segmented_masks=*/segmented_masks,
-        /*image=*/preprocessing[Output<Image>(kImageTag)]};
+    return ImageSegmenterOutputs{/*segmented_masks=*/segmented_masks,
+                                 /*image=*/image_and_tensors.image};
   }
 };
 
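A note on the ConfigureTensorConverterCalculator arithmetic above: it reproduces the model's (v - mean) / std normalization by setting the converter's output float range to the images of the extreme pixel values 0 and 255. A quick check, with assumed metadata values mean = std = 127.5 (a common [-1, 1] setup, not taken from this commit):

  // Hypothetical NormalizationOptions values, for illustration only.
  const float mean = 127.5f;
  const float std = 127.5f;
  const float range_min = (0.0f - mean) / std;    // (0 - 127.5) / 127.5 = -1.0
  const float range_max = (255.0f - mean) / std;  // (255 - 127.5) / 127.5 = 1.0

With those bounds, the calculator maps an 8-bit pixel v linearly to range_min + (range_max - range_min) * v / 255, which simplifies to exactly (v - mean) / std.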
@@ -27,8 +27,10 @@ limitations under the License.
 #include "mediapipe/framework/port/gtest.h"
 #include "mediapipe/framework/port/opencv_core_inc.h"
 #include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
+#include "mediapipe/framework/port/opencv_imgproc_inc.h"
 #include "mediapipe/framework/port/status_matchers.h"
 #include "mediapipe/tasks/cc/components/containers/rect.h"
+#include "mediapipe/tasks/cc/core/base_options.h"
 #include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
 #include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
 #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
@@ -59,6 +61,8 @@ constexpr char kSelfie128x128WithMetadata[] = "selfie_segm_128_128_3.tflite";
 
 constexpr char kSelfie144x256WithMetadata[] = "selfie_segm_144_256_3.tflite";
 
+constexpr char kHairSegmentationWithMetadata[] = "hair_segmentation.tflite";
+
 constexpr float kGoldenMaskSimilarity = 0.98;
 
 // Magnification factor used when creating the golden category masks to make
@@ -87,7 +91,21 @@ Image GetSRGBImage(const std::string& image_path) {
   cv::Mat image_mat = cv::imread(image_path);
   mediapipe::ImageFrame image_frame(
       mediapipe::ImageFormat::SRGB, image_mat.cols, image_mat.rows,
-      image_mat.step, image_mat.data, [image_mat](uint8[]) {});
+      image_mat.step, image_mat.data, [image_mat](uint8_t[]) {});
+  Image image(std::make_shared<mediapipe::ImageFrame>(std::move(image_frame)));
+  return image;
+}
+
+Image GetSRGBAImage(const std::string& image_path) {
+  cv::Mat image_mat = cv::imread(image_path);
+  cv::cvtColor(image_mat, image_mat, cv::COLOR_BGR2RGBA);
+  std::vector<cv::Mat> channels(4);
+  cv::split(image_mat, channels);
+  channels[3].setTo(0);
+  cv::merge(channels.data(), 4, image_mat);
+  mediapipe::ImageFrame image_frame(
+      mediapipe::ImageFormat::SRGBA, image_mat.cols, image_mat.rows,
+      image_mat.step, image_mat.data, [image_mat](uint8_t[]) {});
   Image image(std::make_shared<mediapipe::ImageFrame>(std::move(image_frame)));
   return image;
 }
@@ -202,6 +220,30 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
                    MediaPipeTasksStatus::kRunnerInitializationError))));
 }
 
+TEST_F(CreateFromOptionsTest, FailsWithInputDimsTwoModel) {
+  auto options = std::make_unique<ImageSegmenterOptions>();
+  options->base_options.model_asset_path =
+      JoinPath("./", kTestDataDirectory, "dense.tflite");
+  absl::StatusOr<std::unique_ptr<ImageSegmenter>> result =
+      ImageSegmenter::Create(std::move(options));
+  EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument);
+  EXPECT_THAT(result.status().message(),
+              HasSubstr("Expect segmentation model has input image tensor to "
+                        "be 4 dims."));
+}
+
+TEST_F(CreateFromOptionsTest, FailsWithInputChannelOneModel) {
+  auto options = std::make_unique<ImageSegmenterOptions>();
+  options->base_options.model_asset_path =
+      JoinPath("./", kTestDataDirectory, "conv2d_input_channel_1.tflite");
+  absl::StatusOr<std::unique_ptr<ImageSegmenter>> result =
+      ImageSegmenter::Create(std::move(options));
+  EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument);
+  EXPECT_THAT(result.status().message(),
+              HasSubstr("Expect segmentation model has input image tensor with "
+                        "channels = 3 or 4."));
+}
+
 class ImageModeTest : public tflite_shims::testing::Test {};
 
 TEST_F(ImageModeTest, SucceedsWithCategoryMask) {
@@ -369,6 +411,31 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) {
               SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
 }
 
+TEST_F(ImageModeTest, SucceedsHairSegmentation) {
+  Image image =
+      GetSRGBAImage(JoinPath("./", kTestDataDirectory, "portrait.jpg"));
+  auto options = std::make_unique<ImageSegmenterOptions>();
+  options->base_options.model_asset_path =
+      JoinPath("./", kTestDataDirectory, kHairSegmentationWithMetadata);
+  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
+  options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
+  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
+                          ImageSegmenter::Create(std::move(options)));
+  MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
+  EXPECT_EQ(confidence_masks.size(), 2);
+
+  cv::Mat hair_mask = mediapipe::formats::MatView(
+      confidence_masks[1].GetImageFrameSharedPtr().get());
+  MP_ASSERT_OK(segmenter->Close());
+  cv::Mat expected_mask = cv::imread(
+      JoinPath("./", kTestDataDirectory, "portrait_hair_expected_mask.jpg"),
+      cv::IMREAD_GRAYSCALE);
+  cv::Mat expected_mask_float;
+  expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);
+  EXPECT_THAT(hair_mask,
+              SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
+}
+
 class VideoModeTest : public tflite_shims::testing::Test {};
 
 TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
@@ -548,8 +615,6 @@ TEST_F(LiveStreamModeTest, Succeeds) {
   }
 }
 
-// TODO: Add test for hair segmentation model.
-
 }  // namespace
 }  // namespace image_segmenter
 }  // namespace vision
mediapipe/tasks/testdata/vision/BUILD (vendored, 8 lines changed)
@@ -36,7 +36,9 @@ mediapipe_files(srcs = [
     "coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite",
     "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite",
     "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite",
+    "conv2d_input_channel_1.tflite",
     "deeplabv3.tflite",
+    "dense.tflite",
     "face_detection_full_range.tflite",
     "face_detection_full_range_sparse.tflite",
     "face_detection_short_range.tflite",
@@ -44,6 +46,7 @@ mediapipe_files(srcs = [
     "face_landmark_with_attention.tflite",
     "fist.jpg",
     "fist.png",
+    "hair_segmentation.tflite",
     "hand_landmark_full.tflite",
     "hand_landmark_lite.tflite",
     "hand_landmarker.task",
@@ -64,6 +67,7 @@ mediapipe_files(srcs = [
     "pointing_up.jpg",
     "pointing_up_rotated.jpg",
     "portrait.jpg",
+    "portrait_hair_expected_mask.jpg",
     "portrait_rotated.jpg",
     "right_hands.jpg",
     "right_hands_rotated.jpg",
@@ -117,6 +121,7 @@ filegroup(
         "pointing_up.jpg",
         "pointing_up_rotated.jpg",
         "portrait.jpg",
+        "portrait_hair_expected_mask.jpg",
         "portrait_rotated.jpg",
         "right_hands.jpg",
         "right_hands_rotated.jpg",
@@ -140,12 +145,15 @@ filegroup(
         "coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite",
         "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite",
         "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite",
+        "conv2d_input_channel_1.tflite",
         "deeplabv3.tflite",
+        "dense.tflite",
         "face_detection_full_range.tflite",
         "face_detection_full_range_sparse.tflite",
         "face_detection_short_range.tflite",
         "face_landmark.tflite",
         "face_landmark_with_attention.tflite",
+        "hair_segmentation.tflite",
         "hand_landmark_full.tflite",
         "hand_landmark_lite.tflite",
         "hand_landmarker.task",
third_party/external_files.bzl (vendored, 32 lines changed)
@@ -190,6 +190,12 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/coco_ssd_mobilenet_v1_score_calibration.json?generation=1677522739770755"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_conv2d_input_channel_1_tflite",
+        sha256 = "126edac445967799f3b8b124d15483b1506f6d6cb57a501c1636eb8f2fb3734f",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/conv2d_input_channel_1.tflite?generation=1678218348519744"],
+    )
+
     http_file(
         name = "com_google_mediapipe_corrupted_mobilenet_v1_0_25_224_1_default_1_tflite",
         sha256 = "f0cbeb8061f4c693e20de779ce255af923508492e8a24f6db320845a52facb51",
@@ -202,6 +208,12 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/deeplabv3.tflite?generation=1661875711618421"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_dense_tflite",
+        sha256 = "be9323068461b1cbf412692ee916be30dcb1a5fb59a9ee875d470bc340d9e869",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/dense.tflite?generation=1678218351373709"],
+    )
+
     http_file(
         name = "com_google_mediapipe_dummy_gesture_recognizer_task",
         sha256 = "18e54586bda33300d459ca140cd045f6daf43d897224ba215a16db3423eae18e",
@@ -354,8 +366,8 @@ def external_files():
 
     http_file(
         name = "com_google_mediapipe_hair_segmentation_tflite",
-        sha256 = "d2c940c4fd80edeaf38f5d7387d1b4235ee320ed120080df67c663e749e77633",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/hair_segmentation.tflite?generation=1661875756623461"],
+        sha256 = "0bec40bc9ba97c4143f3d4225a935014abffea37c1f3766ae32aba3f2748e711",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/hair_segmentation.tflite?generation=1678218355806671"],
     )
 
     http_file(
@@ -766,6 +778,12 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_rotated_landmarks.pbtxt?generation=1666629486774022"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_portrait_expected_blendshapes_with_attention_pbtxt",
+        sha256 = "0142d56705093c3d79ea5ee79b8e9454499abee00fc059491e6ca14f5fbab862",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_blendshapes_with_attention.pbtxt?generation=1678218364703223"],
+    )
+
     http_file(
         name = "com_google_mediapipe_portrait_expected_detection_pbtxt",
         sha256 = "ace755f0fd0ba3b2d75e4f8bb1b08d2f65975fd5570898004540dfef735c1c3d",
@@ -780,8 +798,14 @@ def external_files():
 
     http_file(
         name = "com_google_mediapipe_portrait_expected_face_landmarks_with_attention_pbtxt",
-        sha256 = "f2ccd889654b914996e4aab0d7831a3e73d3b63d6c14f6bac4bec5cd3415bce4",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_face_landmarks_with_attention.pbtxt?generation=1676415475626542"],
+        sha256 = "dae959456f001015278f3a1535bd03c9fa0990a3df951135645ce23293be0613",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_face_landmarks_with_attention.pbtxt?generation=1678218367300928"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_portrait_hair_expected_mask_jpg",
+        sha256 = "d9ffc4f2ed0ee2d551d9239942e4dfceebf0c33a56858c84410f32ea4f0c1b2c",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_hair_expected_mask.jpg?generation=1678218370120178"],
     )
 
     http_file(