diff --git a/mediapipe/tasks/cc/vision/image_segmenter/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD index 183b1bb86..fc977c0b5 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD @@ -63,6 +63,8 @@ cc_library( "//mediapipe/calculators/image:image_properties_calculator", "//mediapipe/calculators/image:image_transformation_calculator", "//mediapipe/calculators/image:image_transformation_calculator_cc_proto", + "//mediapipe/calculators/image:set_alpha_calculator", + "//mediapipe/calculators/image:set_alpha_calculator_cc_proto", "//mediapipe/calculators/tensor:image_to_tensor_calculator", "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", "//mediapipe/calculators/tensor:inference_calculator", diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index a52d3fa9a..6ecfa3685 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "mediapipe/calculators/image/image_clone_calculator.pb.h" #include "mediapipe/calculators/image/image_transformation_calculator.pb.h" +#include "mediapipe/calculators/image/set_alpha_calculator.pb.h" #include "mediapipe/calculators/tensor/tensor_converter_calculator.pb.h" #include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/port.h" @@ -249,7 +250,8 @@ void ConfigureTensorConverterCalculator( // the tflite model. 
absl::StatusOr ConvertImageToTensors( Source image_in, Source norm_rect_in, bool use_gpu, - const core::ModelResources& model_resources, Graph& graph) { + bool is_hair_segmentation, const core::ModelResources& model_resources, + Graph& graph) { ASSIGN_OR_RETURN(const tflite::Tensor* tflite_input_tensor, GetInputTensor(model_resources)); if (tflite_input_tensor->shape()->size() != 4) { @@ -294,9 +296,17 @@ absl::StatusOr ConvertImageToTensors( // Convert from Image to legacy ImageFrame or GpuBuffer. auto& from_image = graph.AddNode("FromImageCalculator"); image_on_device >> from_image.In(kImageTag); - auto image_cpu_or_gpu = + Source image_cpu_or_gpu = from_image.Out(use_gpu ? kImageGpuTag : kImageCpuTag); + if (is_hair_segmentation) { + auto& set_alpha = graph.AddNode("SetAlphaCalculator"); + set_alpha.GetOptions() + .set_alpha_value(0); + image_cpu_or_gpu >> set_alpha.In(use_gpu ? kImageGpuTag : kImageTag); + image_cpu_or_gpu = set_alpha.Out(use_gpu ? kImageGpuTag : kImageTag); + } + // Resize the input image to the model input size. auto& image_transformation = graph.AddNode("ImageTransformationCalculator"); ConfigureImageTransformationCalculator( @@ -461,22 +471,41 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { bool use_gpu = components::processors::DetermineImagePreprocessingGpuBackend( task_options.base_options().acceleration()); - ASSIGN_OR_RETURN(auto image_and_tensors, - ConvertImageToTensors(image_in, norm_rect_in, use_gpu, - model_resources, graph)); - // Adds inference subgraph and connects its input stream to the output - // tensors produced by the ImageToTensorCalculator. - auto& inference = AddInference( - model_resources, task_options.base_options().acceleration(), graph); - image_and_tensors.tensors >> inference.In(kTensorsTag); - // Adds segmentation calculators for output streams. + // Adds segmentation calculators for output streams. Add this calculator + // first to get the labels. 
auto& tensor_to_images = graph.AddNode("mediapipe.tasks.TensorsToSegmentationCalculator"); RET_CHECK_OK(ConfigureTensorsToSegmentationCalculator( task_options, model_resources, &tensor_to_images .GetOptions())); + const auto& tensor_to_images_options = + tensor_to_images.GetOptions(); + + // TODO: remove special logic for hair segmentation model. + // The alpha channel of the hair segmentation model indicates the region + // of interest. The model was designed for live stream mode, so the mask of + // the previous frame is used as the indicator for the next frame. For the first + // frame, it expects the alpha channel to be empty. To consolidate IMAGE, + // VIDEO and LIVE_STREAM modes in mediapipe tasks, here we forcibly set the + // alpha channel to be empty if we find the model is the hair segmentation + // model. + bool is_hair_segmentation = false; + if (tensor_to_images_options.label_items_size() == 2 && + tensor_to_images_options.label_items().at(1).name() == "hair") { + is_hair_segmentation = true; + } + + ASSIGN_OR_RETURN( + auto image_and_tensors, + ConvertImageToTensors(image_in, norm_rect_in, use_gpu, + is_hair_segmentation, model_resources, graph)); + // Adds inference subgraph and connects its input stream to the output + // tensors produced by the ImageToTensorCalculator. + auto& inference = AddInference( + model_resources, task_options.base_options().acceleration(), graph); + image_and_tensors.tensors >> inference.In(kTensorsTag); inference.Out(kTensorsTag) >> tensor_to_images.In(kTensorsTag); // Adds image property calculator for output size. 
diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index af9361bb3..599248f48 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -204,8 +204,8 @@ def external_files(): http_file( name = "com_google_mediapipe_conv2d_input_channel_1_tflite", - sha256 = "126edac445967799f3b8b124d15483b1506f6d6cb57a501c1636eb8f2fb3734f", - urls = ["https://storage.googleapis.com/mediapipe-assets/conv2d_input_channel_1.tflite?generation=1678218348519744"], + sha256 = "ccb667092f3aed3a35a57fb3478fecc0c8f6360dbf477a9db9c24e5b3ec4273e", + urls = ["https://storage.googleapis.com/mediapipe-assets/conv2d_input_channel_1.tflite?generation=1683252905577703"], ) http_file( @@ -246,8 +246,8 @@ def external_files(): http_file( name = "com_google_mediapipe_dense_tflite", - sha256 = "be9323068461b1cbf412692ee916be30dcb1a5fb59a9ee875d470bc340d9e869", - urls = ["https://storage.googleapis.com/mediapipe-assets/dense.tflite?generation=1678218351373709"], + sha256 = "6795e7c3a263f44e97be048a5e1166e0921b453bfbaf037f4f69ac5c059ee945", + urls = ["https://storage.googleapis.com/mediapipe-assets/dense.tflite?generation=1683252907920466"], ) http_file(