Support hair segmentation model in image segmenter

PiperOrigin-RevId: 515151150
This commit is contained in:
MediaPipe Team 2023-03-08 14:58:52 -08:00 committed by Copybara-Service
parent 2fb62e4c29
commit 9f1f4273d0
5 changed files with 246 additions and 25 deletions

View File

@ -47,26 +47,35 @@ cc_library(
srcs = ["image_segmenter_graph.cc"], srcs = ["image_segmenter_graph.cc"],
deps = [ deps = [
"//mediapipe/calculators/core:merge_to_vector_calculator", "//mediapipe/calculators/core:merge_to_vector_calculator",
"//mediapipe/calculators/image:image_clone_calculator_cc_proto",
"//mediapipe/calculators/image:image_properties_calculator", "//mediapipe/calculators/image:image_properties_calculator",
"//mediapipe/calculators/image:image_transformation_calculator",
"//mediapipe/calculators/image:image_transformation_calculator_cc_proto",
"//mediapipe/calculators/tensor:image_to_tensor_calculator", "//mediapipe/calculators/tensor:image_to_tensor_calculator",
"//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
"//mediapipe/calculators/tensor:inference_calculator", "//mediapipe/calculators/tensor:inference_calculator",
"//mediapipe/calculators/tensor:tensor_converter_calculator",
"//mediapipe/calculators/tensor:tensor_converter_calculator_cc_proto",
"//mediapipe/calculators/util:from_image_calculator",
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port", "//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/gpu:scale_mode_cc_proto",
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
"//mediapipe/tasks/cc/metadata:metadata_extractor", "//mediapipe/tasks/cc/metadata:metadata_extractor",
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator", "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator",
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_cc_proto", "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_cc_proto",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto", "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto",
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
"//mediapipe/tasks/metadata:metadata_schema_cc", "//mediapipe/tasks/metadata:metadata_schema_cc",
"//mediapipe/util:label_map_cc_proto", "//mediapipe/util:label_map_cc_proto",
"//mediapipe/util:label_map_util", "//mediapipe/util:label_map_util",

View File

@ -20,22 +20,26 @@ limitations under the License.
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "mediapipe/calculators/image/image_clone_calculator.pb.h"
#include "mediapipe/calculators/image/image_transformation_calculator.pb.h"
#include "mediapipe/calculators/tensor/tensor_converter_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/status_macros.h" #include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h"
#include "mediapipe/util/label_map.pb.h" #include "mediapipe/util/label_map.pb.h"
#include "mediapipe/util/label_map_util.h" #include "mediapipe/util/label_map_util.h"
@ -59,13 +63,14 @@ using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
using ::mediapipe::tasks::vision::image_segmenter::proto:: using ::mediapipe::tasks::vision::image_segmenter::proto::
ImageSegmenterGraphOptions; ImageSegmenterGraphOptions;
using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions;
using ::tflite::Tensor;
using ::tflite::TensorMetadata; using ::tflite::TensorMetadata;
using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>; using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
constexpr char kSegmentationTag[] = "SEGMENTATION"; constexpr char kSegmentationTag[] = "SEGMENTATION";
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION"; constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
constexpr char kImageTag[] = "IMAGE"; constexpr char kImageTag[] = "IMAGE";
constexpr char kImageCpuTag[] = "IMAGE_CPU";
constexpr char kImageGpuTag[] = "IMAGE_GPU";
constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTensorsTag[] = "TENSORS";
constexpr char kOutputSizeTag[] = "OUTPUT_SIZE"; constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
@ -78,6 +83,13 @@ struct ImageSegmenterOutputs {
Source<Image> image; Source<Image> image;
}; };
// Struct holding the image and input tensors after image preprocessing and
// transferred to the requested device.
struct ImageAndTensorsOnDevice {
Source<Image> image;
Source<std::vector<Tensor>> tensors;
};
} // namespace } // namespace
absl::Status SanityCheckOptions(const ImageSegmenterGraphOptions& options) { absl::Status SanityCheckOptions(const ImageSegmenterGraphOptions& options) {
@ -144,7 +156,8 @@ absl::Status ConfigureTensorsToSegmentationCalculator(
return absl::OkStatus(); return absl::OkStatus();
} }
absl::StatusOr<const Tensor*> GetOutputTensor( // Get the output tensor from the tflite model of given model resources.
absl::StatusOr<const tflite::Tensor*> GetOutputTensor(
const core::ModelResources& model_resources) { const core::ModelResources& model_resources) {
const tflite::Model& model = *model_resources.GetTfLiteModel(); const tflite::Model& model = *model_resources.GetTfLiteModel();
const auto* primary_subgraph = (*model.subgraphs())[0]; const auto* primary_subgraph = (*model.subgraphs())[0];
@ -153,6 +166,115 @@ absl::StatusOr<const Tensor*> GetOutputTensor(
return output_tensor; return output_tensor;
} }
// Get the input tensor from the tflite model of given model resources.
absl::StatusOr<const tflite::Tensor*> GetInputTensor(
const core::ModelResources& model_resources) {
const tflite::Model& model = *model_resources.GetTfLiteModel();
const auto* primary_subgraph = (*model.subgraphs())[0];
const auto* input_tensor =
(*primary_subgraph->tensors())[(*primary_subgraph->inputs())[0]];
return input_tensor;
}
// Configure the ImageTransformationCalculator according to the input tensor.
void ConfigureImageTransformationCalculator(
const tflite::Tensor& tflite_input_tensor,
mediapipe::ImageTransformationCalculatorOptions& options) {
options.set_output_height(tflite_input_tensor.shape()->data()[1]);
options.set_output_width(tflite_input_tensor.shape()->data()[2]);
}
// Configure the TensorConverterCalculator to convert the image to tensor.
void ConfigureTensorConverterCalculator(
const ImageTensorSpecs& image_tensor_specs,
mediapipe::TensorConverterCalculatorOptions& options) {
float mean = image_tensor_specs.normalization_options->mean_values[0];
float std = image_tensor_specs.normalization_options->std_values[0];
options.set_max_num_channels(4);
options.mutable_output_tensor_float_range()->set_min((0.0f - mean) / std);
options.mutable_output_tensor_float_range()->set_max((255.0f - mean) / std);
}
// Image preprocessing step to convert the given image to the input tensors for
// the tflite model.
absl::StatusOr<ImageAndTensorsOnDevice> ConvertImageToTensors(
Source<Image> image_in, Source<NormalizedRect> norm_rect_in, bool use_gpu,
const core::ModelResources& model_resources, Graph& graph) {
ASSIGN_OR_RETURN(const tflite::Tensor* tflite_input_tensor,
GetInputTensor(model_resources));
if (tflite_input_tensor->shape()->size() != 4) {
return absl::InvalidArgumentError(
absl::StrFormat("Expect segmentation model has input image tensor to "
"be 4 dims. Got input tensor with "
"dims: %d",
tflite_input_tensor->shape()->size()));
}
const int input_tensor_channel = tflite_input_tensor->shape()->data()[3];
if (input_tensor_channel != 3 && input_tensor_channel != 4) {
return absl::InvalidArgumentError(absl::StrFormat(
"Expect segmentation model has input image tensor with channels = 3 or "
"4. Get "
"channel = %d",
tflite_input_tensor->shape()->data()[3]));
} else if (input_tensor_channel == 3) {
// ImagePreprocessingGraph is backed by ImageToTensorCalculator which only
// supports Tensor with channel = 3.
auto& preprocessing = graph.AddNode(
"mediapipe.tasks.components.processors.ImagePreprocessingGraph");
MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
model_resources, use_gpu,
&preprocessing.GetOptions<tasks::components::processors::proto::
ImagePreprocessingGraphOptions>()));
image_in >> preprocessing.In(kImageTag);
norm_rect_in >> preprocessing.In(kNormRectTag);
return {{preprocessing.Out(kImageTag).Cast<Image>(),
preprocessing.Out(kTensorsTag).Cast<std::vector<Tensor>>()}};
} else {
// TODO Remove legacy preprocessing calculators.
// For segmentation model with input Tensor with channel = 4, use legacy
// TfLite preprocessing calculators
// Upload image to GPU if requested to use gpu.
auto& image_clone = graph.AddNode("ImageCloneCalculator");
image_clone.GetOptions<mediapipe::ImageCloneCalculatorOptions>()
.set_output_on_gpu(use_gpu);
image_in >> image_clone.In("");
Source<Image> image_on_device = image_clone.Out("").Cast<Image>();
// Convert from Image to legacy ImageFrame or GpuBuffer.
auto& from_image = graph.AddNode("FromImageCalculator");
image_on_device >> from_image.In(kImageTag);
auto image_cpu_or_gpu =
from_image.Out(use_gpu ? kImageGpuTag : kImageCpuTag);
// Resize the input image to the model input size.
auto& image_transformation = graph.AddNode("ImageTransformationCalculator");
ConfigureImageTransformationCalculator(
*tflite_input_tensor,
image_transformation
.GetOptions<mediapipe::ImageTransformationCalculatorOptions>());
const absl::string_view image_or_image_gpu_tag =
use_gpu ? kImageGpuTag : kImageTag;
image_cpu_or_gpu >> image_transformation.In(image_or_image_gpu_tag);
auto transformed_image = image_transformation.Out(image_or_image_gpu_tag);
// Convert image to mediapipe tensor.
auto& tensor_converter = graph.AddNode("TensorConverterCalculator");
ASSIGN_OR_RETURN(auto image_tensor_specs,
vision::BuildInputImageTensorSpecs(model_resources));
ConfigureTensorConverterCalculator(
image_tensor_specs,
tensor_converter
.GetOptions<mediapipe::TensorConverterCalculatorOptions>());
transformed_image >> tensor_converter.In(image_or_image_gpu_tag);
auto tensors =
tensor_converter.Out(kTensorsTag).Cast<std::vector<Tensor>>();
return {{image_on_device, tensors}};
}
}
// An "mediapipe.tasks.vision.ImageSegmenterGraph" performs semantic // An "mediapipe.tasks.vision.ImageSegmenterGraph" performs semantic
// segmentation. // segmentation.
// Two kinds of outputs are provided: SEGMENTATION and GROUPED_SEGMENTATION. // Two kinds of outputs are provided: SEGMENTATION and GROUPED_SEGMENTATION.
@ -244,23 +366,17 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
// Adds preprocessing calculators and connects them to the graph input image // Adds preprocessing calculators and connects them to the graph input image
// stream. // stream.
auto& preprocessing = graph.AddNode(
"mediapipe.tasks.components.processors.ImagePreprocessingGraph");
bool use_gpu = bool use_gpu =
components::processors::DetermineImagePreprocessingGpuBackend( components::processors::DetermineImagePreprocessingGpuBackend(
task_options.base_options().acceleration()); task_options.base_options().acceleration());
MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( ASSIGN_OR_RETURN(auto image_and_tensors,
model_resources, use_gpu, ConvertImageToTensors(image_in, norm_rect_in, use_gpu,
&preprocessing.GetOptions<tasks::components::processors::proto:: model_resources, graph));
ImagePreprocessingGraphOptions>()));
image_in >> preprocessing.In(kImageTag);
norm_rect_in >> preprocessing.In(kNormRectTag);
// Adds inference subgraph and connects its input stream to the output // Adds inference subgraph and connects its input stream to the output
// tensors produced by the ImageToTensorCalculator. // tensors produced by the ImageToTensorCalculator.
auto& inference = AddInference( auto& inference = AddInference(
model_resources, task_options.base_options().acceleration(), graph); model_resources, task_options.base_options().acceleration(), graph);
preprocessing.Out(kTensorsTag) >> inference.In(kTensorsTag); image_and_tensors.tensors >> inference.In(kTensorsTag);
// Adds segmentation calculators for output streams. // Adds segmentation calculators for output streams.
auto& tensor_to_images = auto& tensor_to_images =
@ -283,7 +399,7 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
segmented_masks.push_back( segmented_masks.push_back(
Source<Image>(tensor_to_images[Output<Image>(kSegmentationTag)])); Source<Image>(tensor_to_images[Output<Image>(kSegmentationTag)]));
} else { } else {
ASSIGN_OR_RETURN(const Tensor* output_tensor, ASSIGN_OR_RETURN(const tflite::Tensor* output_tensor,
GetOutputTensor(model_resources)); GetOutputTensor(model_resources));
const int segmentation_streams_num = *output_tensor->shape()->rbegin(); const int segmentation_streams_num = *output_tensor->shape()->rbegin();
for (int i = 0; i < segmentation_streams_num; ++i) { for (int i = 0; i < segmentation_streams_num; ++i) {
@ -291,9 +407,8 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
tensor_to_images[Output<Image>::Multiple(kSegmentationTag)][i])); tensor_to_images[Output<Image>::Multiple(kSegmentationTag)][i]));
} }
} }
return ImageSegmenterOutputs{ return ImageSegmenterOutputs{/*segmented_masks=*/segmented_masks,
/*segmented_masks=*/segmented_masks, /*image=*/image_and_tensors.image};
/*image=*/preprocessing[Output<Image>(kImageTag)]};
} }
}; };

View File

@ -27,8 +27,10 @@ limitations under the License.
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/components/containers/rect.h" #include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
@ -59,6 +61,8 @@ constexpr char kSelfie128x128WithMetadata[] = "selfie_segm_128_128_3.tflite";
constexpr char kSelfie144x256WithMetadata[] = "selfie_segm_144_256_3.tflite"; constexpr char kSelfie144x256WithMetadata[] = "selfie_segm_144_256_3.tflite";
constexpr char kHairSegmentationWithMetadata[] = "hair_segmentation.tflite";
constexpr float kGoldenMaskSimilarity = 0.98; constexpr float kGoldenMaskSimilarity = 0.98;
// Magnification factor used when creating the golden category masks to make // Magnification factor used when creating the golden category masks to make
@ -87,7 +91,21 @@ Image GetSRGBImage(const std::string& image_path) {
cv::Mat image_mat = cv::imread(image_path); cv::Mat image_mat = cv::imread(image_path);
mediapipe::ImageFrame image_frame( mediapipe::ImageFrame image_frame(
mediapipe::ImageFormat::SRGB, image_mat.cols, image_mat.rows, mediapipe::ImageFormat::SRGB, image_mat.cols, image_mat.rows,
image_mat.step, image_mat.data, [image_mat](uint8[]) {}); image_mat.step, image_mat.data, [image_mat](uint8_t[]) {});
Image image(std::make_shared<mediapipe::ImageFrame>(std::move(image_frame)));
return image;
}
Image GetSRGBAImage(const std::string& image_path) {
cv::Mat image_mat = cv::imread(image_path);
cv::cvtColor(image_mat, image_mat, cv::COLOR_BGR2RGBA);
std::vector<cv::Mat> channels(4);
cv::split(image_mat, channels);
channels[3].setTo(0);
cv::merge(channels.data(), 4, image_mat);
mediapipe::ImageFrame image_frame(
mediapipe::ImageFormat::SRGBA, image_mat.cols, image_mat.rows,
image_mat.step, image_mat.data, [image_mat](uint8_t[]) {});
Image image(std::make_shared<mediapipe::ImageFrame>(std::move(image_frame))); Image image(std::make_shared<mediapipe::ImageFrame>(std::move(image_frame)));
return image; return image;
} }
@ -202,6 +220,30 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
MediaPipeTasksStatus::kRunnerInitializationError)))); MediaPipeTasksStatus::kRunnerInitializationError))));
} }
TEST_F(CreateFromOptionsTest, FailsWithInputDimsTwoModel) {
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, "dense.tflite");
absl::StatusOr<std::unique_ptr<ImageSegmenter>> result =
ImageSegmenter::Create(std::move(options));
EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(result.status().message(),
HasSubstr("Expect segmentation model has input image tensor to "
"be 4 dims."));
}
TEST_F(CreateFromOptionsTest, FailsWithInputChannelOneModel) {
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, "conv2d_input_channel_1.tflite");
absl::StatusOr<std::unique_ptr<ImageSegmenter>> result =
ImageSegmenter::Create(std::move(options));
EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(result.status().message(),
HasSubstr("Expect segmentation model has input image tensor with "
"channels = 3 or 4."));
}
class ImageModeTest : public tflite_shims::testing::Test {}; class ImageModeTest : public tflite_shims::testing::Test {};
TEST_F(ImageModeTest, SucceedsWithCategoryMask) { TEST_F(ImageModeTest, SucceedsWithCategoryMask) {
@ -369,6 +411,31 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) {
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
} }
TEST_F(ImageModeTest, SucceedsHairSegmentation) {
Image image =
GetSRGBAImage(JoinPath("./", kTestDataDirectory, "portrait.jpg"));
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kHairSegmentationWithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
EXPECT_EQ(confidence_masks.size(), 2);
cv::Mat hair_mask = mediapipe::formats::MatView(
confidence_masks[1].GetImageFrameSharedPtr().get());
MP_ASSERT_OK(segmenter->Close());
cv::Mat expected_mask = cv::imread(
JoinPath("./", kTestDataDirectory, "portrait_hair_expected_mask.jpg"),
cv::IMREAD_GRAYSCALE);
cv::Mat expected_mask_float;
expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);
EXPECT_THAT(hair_mask,
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
}
class VideoModeTest : public tflite_shims::testing::Test {}; class VideoModeTest : public tflite_shims::testing::Test {};
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) { TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
@ -548,8 +615,6 @@ TEST_F(LiveStreamModeTest, Succeeds) {
} }
} }
// TODO: Add test for hair segmentation model.
} // namespace } // namespace
} // namespace image_segmenter } // namespace image_segmenter
} // namespace vision } // namespace vision

View File

@ -36,7 +36,9 @@ mediapipe_files(srcs = [
"coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite", "coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite",
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite", "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite",
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite", "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite",
"conv2d_input_channel_1.tflite",
"deeplabv3.tflite", "deeplabv3.tflite",
"dense.tflite",
"face_detection_full_range.tflite", "face_detection_full_range.tflite",
"face_detection_full_range_sparse.tflite", "face_detection_full_range_sparse.tflite",
"face_detection_short_range.tflite", "face_detection_short_range.tflite",
@ -44,6 +46,7 @@ mediapipe_files(srcs = [
"face_landmark_with_attention.tflite", "face_landmark_with_attention.tflite",
"fist.jpg", "fist.jpg",
"fist.png", "fist.png",
"hair_segmentation.tflite",
"hand_landmark_full.tflite", "hand_landmark_full.tflite",
"hand_landmark_lite.tflite", "hand_landmark_lite.tflite",
"hand_landmarker.task", "hand_landmarker.task",
@ -64,6 +67,7 @@ mediapipe_files(srcs = [
"pointing_up.jpg", "pointing_up.jpg",
"pointing_up_rotated.jpg", "pointing_up_rotated.jpg",
"portrait.jpg", "portrait.jpg",
"portrait_hair_expected_mask.jpg",
"portrait_rotated.jpg", "portrait_rotated.jpg",
"right_hands.jpg", "right_hands.jpg",
"right_hands_rotated.jpg", "right_hands_rotated.jpg",
@ -117,6 +121,7 @@ filegroup(
"pointing_up.jpg", "pointing_up.jpg",
"pointing_up_rotated.jpg", "pointing_up_rotated.jpg",
"portrait.jpg", "portrait.jpg",
"portrait_hair_expected_mask.jpg",
"portrait_rotated.jpg", "portrait_rotated.jpg",
"right_hands.jpg", "right_hands.jpg",
"right_hands_rotated.jpg", "right_hands_rotated.jpg",
@ -140,12 +145,15 @@ filegroup(
"coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite", "coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite",
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite", "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite",
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite", "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite",
"conv2d_input_channel_1.tflite",
"deeplabv3.tflite", "deeplabv3.tflite",
"dense.tflite",
"face_detection_full_range.tflite", "face_detection_full_range.tflite",
"face_detection_full_range_sparse.tflite", "face_detection_full_range_sparse.tflite",
"face_detection_short_range.tflite", "face_detection_short_range.tflite",
"face_landmark.tflite", "face_landmark.tflite",
"face_landmark_with_attention.tflite", "face_landmark_with_attention.tflite",
"hair_segmentation.tflite",
"hand_landmark_full.tflite", "hand_landmark_full.tflite",
"hand_landmark_lite.tflite", "hand_landmark_lite.tflite",
"hand_landmarker.task", "hand_landmarker.task",

View File

@ -190,6 +190,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/coco_ssd_mobilenet_v1_score_calibration.json?generation=1677522739770755"], urls = ["https://storage.googleapis.com/mediapipe-assets/coco_ssd_mobilenet_v1_score_calibration.json?generation=1677522739770755"],
) )
http_file(
name = "com_google_mediapipe_conv2d_input_channel_1_tflite",
sha256 = "126edac445967799f3b8b124d15483b1506f6d6cb57a501c1636eb8f2fb3734f",
urls = ["https://storage.googleapis.com/mediapipe-assets/conv2d_input_channel_1.tflite?generation=1678218348519744"],
)
http_file( http_file(
name = "com_google_mediapipe_corrupted_mobilenet_v1_0_25_224_1_default_1_tflite", name = "com_google_mediapipe_corrupted_mobilenet_v1_0_25_224_1_default_1_tflite",
sha256 = "f0cbeb8061f4c693e20de779ce255af923508492e8a24f6db320845a52facb51", sha256 = "f0cbeb8061f4c693e20de779ce255af923508492e8a24f6db320845a52facb51",
@ -202,6 +208,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/deeplabv3.tflite?generation=1661875711618421"], urls = ["https://storage.googleapis.com/mediapipe-assets/deeplabv3.tflite?generation=1661875711618421"],
) )
http_file(
name = "com_google_mediapipe_dense_tflite",
sha256 = "be9323068461b1cbf412692ee916be30dcb1a5fb59a9ee875d470bc340d9e869",
urls = ["https://storage.googleapis.com/mediapipe-assets/dense.tflite?generation=1678218351373709"],
)
http_file( http_file(
name = "com_google_mediapipe_dummy_gesture_recognizer_task", name = "com_google_mediapipe_dummy_gesture_recognizer_task",
sha256 = "18e54586bda33300d459ca140cd045f6daf43d897224ba215a16db3423eae18e", sha256 = "18e54586bda33300d459ca140cd045f6daf43d897224ba215a16db3423eae18e",
@ -354,8 +366,8 @@ def external_files():
http_file( http_file(
name = "com_google_mediapipe_hair_segmentation_tflite", name = "com_google_mediapipe_hair_segmentation_tflite",
sha256 = "d2c940c4fd80edeaf38f5d7387d1b4235ee320ed120080df67c663e749e77633", sha256 = "0bec40bc9ba97c4143f3d4225a935014abffea37c1f3766ae32aba3f2748e711",
urls = ["https://storage.googleapis.com/mediapipe-assets/hair_segmentation.tflite?generation=1661875756623461"], urls = ["https://storage.googleapis.com/mediapipe-assets/hair_segmentation.tflite?generation=1678218355806671"],
) )
http_file( http_file(
@ -766,6 +778,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_rotated_landmarks.pbtxt?generation=1666629486774022"], urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_rotated_landmarks.pbtxt?generation=1666629486774022"],
) )
http_file(
name = "com_google_mediapipe_portrait_expected_blendshapes_with_attention_pbtxt",
sha256 = "0142d56705093c3d79ea5ee79b8e9454499abee00fc059491e6ca14f5fbab862",
urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_blendshapes_with_attention.pbtxt?generation=1678218364703223"],
)
http_file( http_file(
name = "com_google_mediapipe_portrait_expected_detection_pbtxt", name = "com_google_mediapipe_portrait_expected_detection_pbtxt",
sha256 = "ace755f0fd0ba3b2d75e4f8bb1b08d2f65975fd5570898004540dfef735c1c3d", sha256 = "ace755f0fd0ba3b2d75e4f8bb1b08d2f65975fd5570898004540dfef735c1c3d",
@ -780,8 +798,14 @@ def external_files():
http_file( http_file(
name = "com_google_mediapipe_portrait_expected_face_landmarks_with_attention_pbtxt", name = "com_google_mediapipe_portrait_expected_face_landmarks_with_attention_pbtxt",
sha256 = "f2ccd889654b914996e4aab0d7831a3e73d3b63d6c14f6bac4bec5cd3415bce4", sha256 = "dae959456f001015278f3a1535bd03c9fa0990a3df951135645ce23293be0613",
urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_face_landmarks_with_attention.pbtxt?generation=1676415475626542"], urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_face_landmarks_with_attention.pbtxt?generation=1678218367300928"],
)
http_file(
name = "com_google_mediapipe_portrait_hair_expected_mask_jpg",
sha256 = "d9ffc4f2ed0ee2d551d9239942e4dfceebf0c33a56858c84410f32ea4f0c1b2c",
urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_hair_expected_mask.jpg?generation=1678218370120178"],
) )
http_file( http_file(