diff --git a/mediapipe/tasks/cc/vision/image_segmenter/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD
index 4c9c6e69c..b084331c8 100644
--- a/mediapipe/tasks/cc/vision/image_segmenter/BUILD
+++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD
@@ -47,26 +47,35 @@ cc_library(
     srcs = ["image_segmenter_graph.cc"],
     deps = [
         "//mediapipe/calculators/core:merge_to_vector_calculator",
+        "//mediapipe/calculators/image:image_clone_calculator_cc_proto",
         "//mediapipe/calculators/image:image_properties_calculator",
+        "//mediapipe/calculators/image:image_transformation_calculator",
+        "//mediapipe/calculators/image:image_transformation_calculator_cc_proto",
         "//mediapipe/calculators/tensor:image_to_tensor_calculator",
+        "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
         "//mediapipe/calculators/tensor:inference_calculator",
+        "//mediapipe/calculators/tensor:tensor_converter_calculator",
+        "//mediapipe/calculators/tensor:tensor_converter_calculator_cc_proto",
+        "//mediapipe/calculators/util:from_image_calculator",
         "//mediapipe/framework/api2:builder",
         "//mediapipe/framework/api2:port",
         "//mediapipe/framework/formats:image",
         "//mediapipe/framework/formats:rect_cc_proto",
+        "//mediapipe/framework/formats:tensor",
         "//mediapipe/framework/port:status",
+        "//mediapipe/gpu:scale_mode_cc_proto",
         "//mediapipe/tasks/cc:common",
         "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
         "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto",
         "//mediapipe/tasks/cc/core:model_resources",
         "//mediapipe/tasks/cc/core:model_task_graph",
-        "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
         "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
         "//mediapipe/tasks/cc/metadata:metadata_extractor",
         "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator",
         "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_cc_proto",
         "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
         "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto",
+        "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
         "//mediapipe/tasks/metadata:metadata_schema_cc",
         "//mediapipe/util:label_map_cc_proto",
         "//mediapipe/util:label_map_util",
#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "mediapipe/calculators/image/image_clone_calculator.pb.h" +#include "mediapipe/calculators/image/image_transformation_calculator.pb.h" +#include "mediapipe/calculators/tensor/tensor_converter_calculator.pb.h" #include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/port/status_macros.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" -#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" #include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h" +#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" #include "mediapipe/util/label_map.pb.h" #include "mediapipe/util/label_map_util.h" @@ -59,13 +63,14 @@ using ::mediapipe::tasks::metadata::ModelMetadataExtractor; using ::mediapipe::tasks::vision::image_segmenter::proto:: ImageSegmenterGraphOptions; using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; -using ::tflite::Tensor; using ::tflite::TensorMetadata; using LabelItems = mediapipe::proto_ns::Map; constexpr char kSegmentationTag[] = "SEGMENTATION"; constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION"; constexpr char kImageTag[] = "IMAGE"; +constexpr char kImageCpuTag[] = "IMAGE_CPU"; +constexpr char kImageGpuTag[] = "IMAGE_GPU"; constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kTensorsTag[] = "TENSORS"; constexpr char kOutputSizeTag[] = "OUTPUT_SIZE"; @@ -78,6 +83,13 @@ struct ImageSegmenterOutputs { Source image; }; +// Struct holding the image and input tensors after image preprocessing and +// transferred to the requested device. +struct ImageAndTensorsOnDevice { + Source image; + Source> tensors; +}; + } // namespace absl::Status SanityCheckOptions(const ImageSegmenterGraphOptions& options) { @@ -144,7 +156,8 @@ absl::Status ConfigureTensorsToSegmentationCalculator( return absl::OkStatus(); } -absl::StatusOr GetOutputTensor( +// Get the output tensor from the tflite model of given model resources. +absl::StatusOr GetOutputTensor( const core::ModelResources& model_resources) { const tflite::Model& model = *model_resources.GetTfLiteModel(); const auto* primary_subgraph = (*model.subgraphs())[0]; @@ -153,6 +166,115 @@ absl::StatusOr GetOutputTensor( return output_tensor; } +// Get the input tensor from the tflite model of given model resources. 
@@ -153,6 +166,115 @@ absl::StatusOr<const Tensor*> GetOutputTensor(
   return output_tensor;
 }
 
+// Get the input tensor from the tflite model of the given model resources.
+absl::StatusOr<const tflite::Tensor*> GetInputTensor(
+    const core::ModelResources& model_resources) {
+  const tflite::Model& model = *model_resources.GetTfLiteModel();
+  const auto* primary_subgraph = (*model.subgraphs())[0];
+  const auto* input_tensor =
+      (*primary_subgraph->tensors())[(*primary_subgraph->inputs())[0]];
+  return input_tensor;
+}
+
+// Configure the ImageTransformationCalculator according to the input tensor.
+void ConfigureImageTransformationCalculator(
+    const tflite::Tensor& tflite_input_tensor,
+    mediapipe::ImageTransformationCalculatorOptions& options) {
+  options.set_output_height(tflite_input_tensor.shape()->data()[1]);
+  options.set_output_width(tflite_input_tensor.shape()->data()[2]);
+}
+
+// Configure the TensorConverterCalculator to convert the image to tensor.
+void ConfigureTensorConverterCalculator(
+    const ImageTensorSpecs& image_tensor_specs,
+    mediapipe::TensorConverterCalculatorOptions& options) {
+  float mean = image_tensor_specs.normalization_options->mean_values[0];
+  float std = image_tensor_specs.normalization_options->std_values[0];
+  options.set_max_num_channels(4);
+  options.mutable_output_tensor_float_range()->set_min((0.0f - mean) / std);
+  options.mutable_output_tensor_float_range()->set_max((255.0f - mean) / std);
+}
+
+// Image preprocessing step to convert the given image to the input tensors
+// for the tflite model.
+absl::StatusOr<ImageAndTensorsOnDevice> ConvertImageToTensors(
+    Source<Image> image_in, Source<NormalizedRect> norm_rect_in, bool use_gpu,
+    const core::ModelResources& model_resources, Graph& graph) {
+  ASSIGN_OR_RETURN(const tflite::Tensor* tflite_input_tensor,
+                   GetInputTensor(model_resources));
+  if (tflite_input_tensor->shape()->size() != 4) {
+    return absl::InvalidArgumentError(
+        absl::StrFormat("Expect segmentation model to have an input image "
+                        "tensor with 4 dims. Got input tensor with "
+                        "dims: %d",
+                        tflite_input_tensor->shape()->size()));
+  }
+  const int input_tensor_channel = tflite_input_tensor->shape()->data()[3];
+  if (input_tensor_channel != 3 && input_tensor_channel != 4) {
+    return absl::InvalidArgumentError(absl::StrFormat(
+        "Expect segmentation model to have an input image tensor with 3 or 4 "
+        "channels. Got channels = %d",
+        tflite_input_tensor->shape()->data()[3]));
+  } else if (input_tensor_channel == 3) {
+    // ImagePreprocessingGraph is backed by ImageToTensorCalculator, which
+    // only supports Tensor with channels = 3.
+    auto& preprocessing = graph.AddNode(
+        "mediapipe.tasks.components.processors.ImagePreprocessingGraph");
+    MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
+        model_resources, use_gpu,
+        &preprocessing.GetOptions<tasks::components::processors::proto::
+                                      ImagePreprocessingGraphOptions>()));
+    image_in >> preprocessing.In(kImageTag);
+    norm_rect_in >> preprocessing.In(kNormRectTag);
+    return {{preprocessing.Out(kImageTag).Cast<Image>(),
+             preprocessing.Out(kTensorsTag).Cast<std::vector<Tensor>>()}};
+  } else {
+    // TODO: Remove legacy preprocessing calculators.
+    // For a segmentation model with a 4-channel input tensor, use the legacy
+    // TfLite preprocessing calculators.
+
+    // Upload the image to GPU if GPU use is requested.
+    auto& image_clone = graph.AddNode("ImageCloneCalculator");
+    image_clone.GetOptions<mediapipe::ImageCloneCalculatorOptions>()
+        .set_output_on_gpu(use_gpu);
+    image_in >> image_clone.In("");
+    Source<Image> image_on_device = image_clone.Out("").Cast<Image>();
+
+    // Convert from Image to legacy ImageFrame or GpuBuffer.
+    auto& from_image = graph.AddNode("FromImageCalculator");
+    image_on_device >> from_image.In(kImageTag);
+    auto image_cpu_or_gpu =
+        from_image.Out(use_gpu ? kImageGpuTag : kImageCpuTag);
+
+    // Resize the input image to the model input size.
+    auto& image_transformation = graph.AddNode("ImageTransformationCalculator");
+    ConfigureImageTransformationCalculator(
+        *tflite_input_tensor,
+        image_transformation
+            .GetOptions<mediapipe::ImageTransformationCalculatorOptions>());
+    const absl::string_view image_or_image_gpu_tag =
+        use_gpu ? kImageGpuTag : kImageTag;
+    image_cpu_or_gpu >> image_transformation.In(image_or_image_gpu_tag);
+    auto transformed_image = image_transformation.Out(image_or_image_gpu_tag);
+
+    // Convert the image to a mediapipe tensor.
+    auto& tensor_converter = graph.AddNode("TensorConverterCalculator");
+    ASSIGN_OR_RETURN(auto image_tensor_specs,
+                     vision::BuildInputImageTensorSpecs(model_resources));
+    ConfigureTensorConverterCalculator(
+        image_tensor_specs,
+        tensor_converter
+            .GetOptions<mediapipe::TensorConverterCalculatorOptions>());
+
+    transformed_image >> tensor_converter.In(image_or_image_gpu_tag);
+    auto tensors =
+        tensor_converter.Out(kTensorsTag).Cast<std::vector<Tensor>>();
+
+    return {{image_on_device, tensors}};
+  }
+}
+
 // An "mediapipe.tasks.vision.ImageSegmenterGraph" performs semantic
 // segmentation.
 // Two kinds of outputs are provided: SEGMENTATION and GROUPED_SEGMENTATION.
@@ -244,23 +366,17 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
 
     // Adds preprocessing calculators and connects them to the graph input image
     // stream.
-    auto& preprocessing = graph.AddNode(
-        "mediapipe.tasks.components.processors.ImagePreprocessingGraph");
     bool use_gpu =
         components::processors::DetermineImagePreprocessingGpuBackend(
             task_options.base_options().acceleration());
-    MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
-        model_resources, use_gpu,
-        &preprocessing.GetOptions<tasks::components::processors::proto::
-                                      ImagePreprocessingGraphOptions>()));
-    image_in >> preprocessing.In(kImageTag);
-    norm_rect_in >> preprocessing.In(kNormRectTag);
-
+    ASSIGN_OR_RETURN(auto image_and_tensors,
+                     ConvertImageToTensors(image_in, norm_rect_in, use_gpu,
+                                           model_resources, graph));
     // Adds inference subgraph and connects its input stream to the output
     // tensors produced by the ImageToTensorCalculator.
     auto& inference = AddInference(
         model_resources, task_options.base_options().acceleration(), graph);
-    preprocessing.Out(kTensorsTag) >> inference.In(kTensorsTag);
+    image_and_tensors.tensors >> inference.In(kTensorsTag);
 
     // Adds segmentation calculators for output streams.
     auto& tensor_to_images =
@@ -283,7 +399,7 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
       segmented_masks.push_back(
           Source<Image>(tensor_to_images[Output<Image>(kSegmentationTag)]));
     } else {
-      ASSIGN_OR_RETURN(const Tensor* output_tensor,
+      ASSIGN_OR_RETURN(const tflite::Tensor* output_tensor,
                        GetOutputTensor(model_resources));
       const int segmentation_streams_num = *output_tensor->shape()->rbegin();
       for (int i = 0; i < segmentation_streams_num; ++i) {
@@ -291,9 +407,8 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
             tensor_to_images[Output<Image>::Multiple(kSegmentationTag)][i]));
       }
     }
-    return ImageSegmenterOutputs{
-        /*segmented_masks=*/segmented_masks,
-        /*image=*/preprocessing[Output<Image>(kImageTag)]};
+    return ImageSegmenterOutputs{/*segmented_masks=*/segmented_masks,
+                                 /*image=*/image_and_tensors.image};
   }
 };
 
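With the graph changes in place, a 4-channel model such as hair segmentation can be driven through the public task API. A hedged usage sketch (the model path and hair-class index follow the new test below; `RunHairSegmentation` and its argument are illustrative, not part of the patch):

```cpp
#include <memory>
#include <utility>

#include "absl/status/status.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h"

namespace {
using ::mediapipe::tasks::vision::image_segmenter::ImageSegmenter;
using ::mediapipe::tasks::vision::image_segmenter::ImageSegmenterOptions;
}  // namespace

absl::Status RunHairSegmentation(const mediapipe::Image& srgba_image) {
  auto options = std::make_unique<ImageSegmenterOptions>();
  options->base_options.model_asset_path = "hair_segmentation.tflite";
  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
  options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
  ASSIGN_OR_RETURN(std::unique_ptr<ImageSegmenter> segmenter,
                   ImageSegmenter::Create(std::move(options)));
  // For this model, mask 0 is background and mask 1 is hair.
  ASSIGN_OR_RETURN(auto confidence_masks, segmenter->Segment(srgba_image));
  return segmenter->Close();
}
```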
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_imgcodecs_inc.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/tasks/cc/components/containers/rect.h" +#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" @@ -59,6 +61,8 @@ constexpr char kSelfie128x128WithMetadata[] = "selfie_segm_128_128_3.tflite"; constexpr char kSelfie144x256WithMetadata[] = "selfie_segm_144_256_3.tflite"; +constexpr char kHairSegmentationWithMetadata[] = "hair_segmentation.tflite"; + constexpr float kGoldenMaskSimilarity = 0.98; // Magnification factor used when creating the golden category masks to make @@ -87,7 +91,21 @@ Image GetSRGBImage(const std::string& image_path) { cv::Mat image_mat = cv::imread(image_path); mediapipe::ImageFrame image_frame( mediapipe::ImageFormat::SRGB, image_mat.cols, image_mat.rows, - image_mat.step, image_mat.data, [image_mat](uint8[]) {}); + image_mat.step, image_mat.data, [image_mat](uint8_t[]) {}); + Image image(std::make_shared(std::move(image_frame))); + return image; +} + +Image GetSRGBAImage(const std::string& image_path) { + cv::Mat image_mat = cv::imread(image_path); + cv::cvtColor(image_mat, image_mat, cv::COLOR_BGR2RGBA); + std::vector channels(4); + cv::split(image_mat, channels); + channels[3].setTo(0); + cv::merge(channels.data(), 4, image_mat); + mediapipe::ImageFrame image_frame( + mediapipe::ImageFormat::SRGBA, image_mat.cols, image_mat.rows, + image_mat.step, image_mat.data, [image_mat](uint8_t[]) {}); Image image(std::make_shared(std::move(image_frame))); return image; } @@ -202,6 +220,30 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) { MediaPipeTasksStatus::kRunnerInitializationError)))); } +TEST_F(CreateFromOptionsTest, FailsWithInputDimsTwoModel) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, "dense.tflite"); + absl::StatusOr> result = + ImageSegmenter::Create(std::move(options)); + EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(result.status().message(), + HasSubstr("Expect segmentation model has input image tensor to " + "be 4 dims.")); +} + +TEST_F(CreateFromOptionsTest, FailsWithInputChannelOneModel) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, "conv2d_input_channel_1.tflite"); + absl::StatusOr> result = + ImageSegmenter::Create(std::move(options)); + EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(result.status().message(), + HasSubstr("Expect segmentation model has input image tensor with " + "channels = 3 or 4.")); +} + class ImageModeTest : public tflite_shims::testing::Test {}; TEST_F(ImageModeTest, SucceedsWithCategoryMask) { @@ -369,6 +411,31 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) { SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } +TEST_F(ImageModeTest, SucceedsHairSegmentation) { + Image image = + GetSRGBAImage(JoinPath("./", kTestDataDirectory, "portrait.jpg")); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kHairSegmentationWithMetadata); + options->output_type = 
@@ -202,6 +220,30 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
                            MediaPipeTasksStatus::kRunnerInitializationError))));
 }
 
+TEST_F(CreateFromOptionsTest, FailsWithInputDimsTwoModel) {
+  auto options = std::make_unique<ImageSegmenterOptions>();
+  options->base_options.model_asset_path =
+      JoinPath("./", kTestDataDirectory, "dense.tflite");
+  absl::StatusOr<std::unique_ptr<ImageSegmenter>> result =
+      ImageSegmenter::Create(std::move(options));
+  EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument);
+  EXPECT_THAT(result.status().message(),
+              HasSubstr("Expect segmentation model to have an input image "
+                        "tensor with 4 dims."));
+}
+
+TEST_F(CreateFromOptionsTest, FailsWithInputChannelOneModel) {
+  auto options = std::make_unique<ImageSegmenterOptions>();
+  options->base_options.model_asset_path =
+      JoinPath("./", kTestDataDirectory, "conv2d_input_channel_1.tflite");
+  absl::StatusOr<std::unique_ptr<ImageSegmenter>> result =
+      ImageSegmenter::Create(std::move(options));
+  EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument);
+  EXPECT_THAT(result.status().message(),
+              HasSubstr("Expect segmentation model to have an input image "
+                        "tensor with 3 or 4 channels."));
+}
+
 class ImageModeTest : public tflite_shims::testing::Test {};
 
 TEST_F(ImageModeTest, SucceedsWithCategoryMask) {
@@ -369,6 +411,31 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) {
               SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
 }
 
+TEST_F(ImageModeTest, SucceedsHairSegmentation) {
+  Image image =
+      GetSRGBAImage(JoinPath("./", kTestDataDirectory, "portrait.jpg"));
+  auto options = std::make_unique<ImageSegmenterOptions>();
+  options->base_options.model_asset_path =
+      JoinPath("./", kTestDataDirectory, kHairSegmentationWithMetadata);
+  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
+  options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
+  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
+                          ImageSegmenter::Create(std::move(options)));
+  MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
+  EXPECT_EQ(confidence_masks.size(), 2);
+
+  cv::Mat hair_mask = mediapipe::formats::MatView(
+      confidence_masks[1].GetImageFrameSharedPtr().get());
+  MP_ASSERT_OK(segmenter->Close());
+  cv::Mat expected_mask = cv::imread(
+      JoinPath("./", kTestDataDirectory, "portrait_hair_expected_mask.jpg"),
+      cv::IMREAD_GRAYSCALE);
+  cv::Mat expected_mask_float;
+  expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);
+  EXPECT_THAT(hair_mask,
+              SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
+}
+
 class VideoModeTest : public tflite_shims::testing::Test {};
 
 TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
@@ -548,8 +615,6 @@ TEST_F(LiveStreamModeTest, Succeeds) {
   }
 }
 
-// TODO: Add test for hair segmentation model.
-
 }  // namespace
 }  // namespace image_segmenter
 }  // namespace vision
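The new hair segmentation test compares against an 8-bit golden mask rescaled to [0, 1] floats before running `SimilarToFloatMask`. That conversion in isolation (the path argument is hypothetical):

```cpp
#include <string>

#include <opencv2/core.hpp>
#include <opencv2/imgcodecs.hpp>

// Loads a grayscale golden mask and rescales it from [0, 255] to [0, 1],
// matching what SucceedsHairSegmentation does before the comparison.
cv::Mat LoadGoldenMask(const std::string& path) {
  cv::Mat mask_u8 = cv::imread(path, cv::IMREAD_GRAYSCALE);
  cv::Mat mask_f32;
  mask_u8.convertTo(mask_f32, CV_32FC1, 1.0 / 255.0);
  return mask_f32;
}
```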
["https://storage.googleapis.com/mediapipe-assets/conv2d_input_channel_1.tflite?generation=1678218348519744"], + ) + http_file( name = "com_google_mediapipe_corrupted_mobilenet_v1_0_25_224_1_default_1_tflite", sha256 = "f0cbeb8061f4c693e20de779ce255af923508492e8a24f6db320845a52facb51", @@ -202,6 +208,12 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/deeplabv3.tflite?generation=1661875711618421"], ) + http_file( + name = "com_google_mediapipe_dense_tflite", + sha256 = "be9323068461b1cbf412692ee916be30dcb1a5fb59a9ee875d470bc340d9e869", + urls = ["https://storage.googleapis.com/mediapipe-assets/dense.tflite?generation=1678218351373709"], + ) + http_file( name = "com_google_mediapipe_dummy_gesture_recognizer_task", sha256 = "18e54586bda33300d459ca140cd045f6daf43d897224ba215a16db3423eae18e", @@ -354,8 +366,8 @@ def external_files(): http_file( name = "com_google_mediapipe_hair_segmentation_tflite", - sha256 = "d2c940c4fd80edeaf38f5d7387d1b4235ee320ed120080df67c663e749e77633", - urls = ["https://storage.googleapis.com/mediapipe-assets/hair_segmentation.tflite?generation=1661875756623461"], + sha256 = "0bec40bc9ba97c4143f3d4225a935014abffea37c1f3766ae32aba3f2748e711", + urls = ["https://storage.googleapis.com/mediapipe-assets/hair_segmentation.tflite?generation=1678218355806671"], ) http_file( @@ -766,6 +778,12 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_rotated_landmarks.pbtxt?generation=1666629486774022"], ) + http_file( + name = "com_google_mediapipe_portrait_expected_blendshapes_with_attention_pbtxt", + sha256 = "0142d56705093c3d79ea5ee79b8e9454499abee00fc059491e6ca14f5fbab862", + urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_blendshapes_with_attention.pbtxt?generation=1678218364703223"], + ) + http_file( name = "com_google_mediapipe_portrait_expected_detection_pbtxt", sha256 = "ace755f0fd0ba3b2d75e4f8bb1b08d2f65975fd5570898004540dfef735c1c3d", @@ -780,8 +798,14 @@ def external_files(): http_file( name = "com_google_mediapipe_portrait_expected_face_landmarks_with_attention_pbtxt", - sha256 = "f2ccd889654b914996e4aab0d7831a3e73d3b63d6c14f6bac4bec5cd3415bce4", - urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_face_landmarks_with_attention.pbtxt?generation=1676415475626542"], + sha256 = "dae959456f001015278f3a1535bd03c9fa0990a3df951135645ce23293be0613", + urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_face_landmarks_with_attention.pbtxt?generation=1678218367300928"], + ) + + http_file( + name = "com_google_mediapipe_portrait_hair_expected_mask_jpg", + sha256 = "d9ffc4f2ed0ee2d551d9239942e4dfceebf0c33a56858c84410f32ea4f0c1b2c", + urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_hair_expected_mask.jpg?generation=1678218370120178"], ) http_file(