Internal change
PiperOrigin-RevId: 529617578
This commit is contained in:
parent
18d893c697
commit
c24e7a250c
|
@ -63,6 +63,8 @@ cc_library(
|
|||
"//mediapipe/calculators/image:image_properties_calculator",
|
||||
"//mediapipe/calculators/image:image_transformation_calculator",
|
||||
"//mediapipe/calculators/image:image_transformation_calculator_cc_proto",
|
||||
"//mediapipe/calculators/image:set_alpha_calculator",
|
||||
"//mediapipe/calculators/image:set_alpha_calculator_cc_proto",
|
||||
"//mediapipe/calculators/tensor:image_to_tensor_calculator",
|
||||
"//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
|
||||
"//mediapipe/calculators/tensor:inference_calculator",
|
||||
|
|
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||
#include "absl/strings/str_format.h"
|
||||
#include "mediapipe/calculators/image/image_clone_calculator.pb.h"
|
||||
#include "mediapipe/calculators/image/image_transformation_calculator.pb.h"
|
||||
#include "mediapipe/calculators/image/set_alpha_calculator.pb.h"
|
||||
#include "mediapipe/calculators/tensor/tensor_converter_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
|
@ -249,7 +250,8 @@ void ConfigureTensorConverterCalculator(
|
|||
// the tflite model.
|
||||
absl::StatusOr<ImageAndTensorsOnDevice> ConvertImageToTensors(
|
||||
Source<Image> image_in, Source<NormalizedRect> norm_rect_in, bool use_gpu,
|
||||
const core::ModelResources& model_resources, Graph& graph) {
|
||||
bool is_hair_segmentation, const core::ModelResources& model_resources,
|
||||
Graph& graph) {
|
||||
ASSIGN_OR_RETURN(const tflite::Tensor* tflite_input_tensor,
|
||||
GetInputTensor(model_resources));
|
||||
if (tflite_input_tensor->shape()->size() != 4) {
|
||||
|
@ -294,9 +296,17 @@ absl::StatusOr<ImageAndTensorsOnDevice> ConvertImageToTensors(
|
|||
// Convert from Image to legacy ImageFrame or GpuBuffer.
|
||||
auto& from_image = graph.AddNode("FromImageCalculator");
|
||||
image_on_device >> from_image.In(kImageTag);
|
||||
auto image_cpu_or_gpu =
|
||||
Source<api2::AnyType> image_cpu_or_gpu =
|
||||
from_image.Out(use_gpu ? kImageGpuTag : kImageCpuTag);
|
||||
|
||||
if (is_hair_segmentation) {
|
||||
auto& set_alpha = graph.AddNode("SetAlphaCalculator");
|
||||
set_alpha.GetOptions<mediapipe::SetAlphaCalculatorOptions>()
|
||||
.set_alpha_value(0);
|
||||
image_cpu_or_gpu >> set_alpha.In(use_gpu ? kImageGpuTag : kImageTag);
|
||||
image_cpu_or_gpu = set_alpha.Out(use_gpu ? kImageGpuTag : kImageTag);
|
||||
}
|
||||
|
||||
// Resize the input image to the model input size.
|
||||
auto& image_transformation = graph.AddNode("ImageTransformationCalculator");
|
||||
ConfigureImageTransformationCalculator(
|
||||
|
@ -461,22 +471,41 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
|
|||
bool use_gpu =
|
||||
components::processors::DetermineImagePreprocessingGpuBackend(
|
||||
task_options.base_options().acceleration());
|
||||
ASSIGN_OR_RETURN(auto image_and_tensors,
|
||||
ConvertImageToTensors(image_in, norm_rect_in, use_gpu,
|
||||
model_resources, graph));
|
||||
// Adds inference subgraph and connects its input stream to the output
|
||||
// tensors produced by the ImageToTensorCalculator.
|
||||
auto& inference = AddInference(
|
||||
model_resources, task_options.base_options().acceleration(), graph);
|
||||
image_and_tensors.tensors >> inference.In(kTensorsTag);
|
||||
|
||||
// Adds segmentation calculators for output streams.
|
||||
// Adds segmentation calculators for output streams. Add this calculator
|
||||
// first to get the labels.
|
||||
auto& tensor_to_images =
|
||||
graph.AddNode("mediapipe.tasks.TensorsToSegmentationCalculator");
|
||||
RET_CHECK_OK(ConfigureTensorsToSegmentationCalculator(
|
||||
task_options, model_resources,
|
||||
&tensor_to_images
|
||||
.GetOptions<TensorsToSegmentationCalculatorOptions>()));
|
||||
const auto& tensor_to_images_options =
|
||||
tensor_to_images.GetOptions<TensorsToSegmentationCalculatorOptions>();
|
||||
|
||||
// TODO: remove special logic for hair segmentation model.
|
||||
// The alpha channel of hair segmentation model indicates the interested
|
||||
// area. The model was designed for live stream mode, so that the mask of
|
||||
// previous frame is used as the indicator for the next frame. For the first
|
||||
// frame, it expects the alpha channel to be empty. To consolidate IMAGE,
|
||||
// VIDEO and LIVE_STREAM mode in mediapipe tasks, here we forcely set the
|
||||
// alpha channel to be empty if we find the model is the hair segmentation
|
||||
// model.
|
||||
bool is_hair_segmentation = false;
|
||||
if (tensor_to_images_options.label_items_size() == 2 &&
|
||||
tensor_to_images_options.label_items().at(1).name() == "hair") {
|
||||
is_hair_segmentation = true;
|
||||
}
|
||||
|
||||
ASSIGN_OR_RETURN(
|
||||
auto image_and_tensors,
|
||||
ConvertImageToTensors(image_in, norm_rect_in, use_gpu,
|
||||
is_hair_segmentation, model_resources, graph));
|
||||
// Adds inference subgraph and connects its input stream to the output
|
||||
// tensors produced by the ImageToTensorCalculator.
|
||||
auto& inference = AddInference(
|
||||
model_resources, task_options.base_options().acceleration(), graph);
|
||||
image_and_tensors.tensors >> inference.In(kTensorsTag);
|
||||
inference.Out(kTensorsTag) >> tensor_to_images.In(kTensorsTag);
|
||||
|
||||
// Adds image property calculator for output size.
|
||||
|
|
8
third_party/external_files.bzl
vendored
8
third_party/external_files.bzl
vendored
|
@ -204,8 +204,8 @@ def external_files():
|
|||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_conv2d_input_channel_1_tflite",
|
||||
sha256 = "126edac445967799f3b8b124d15483b1506f6d6cb57a501c1636eb8f2fb3734f",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/conv2d_input_channel_1.tflite?generation=1678218348519744"],
|
||||
sha256 = "ccb667092f3aed3a35a57fb3478fecc0c8f6360dbf477a9db9c24e5b3ec4273e",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/conv2d_input_channel_1.tflite?generation=1683252905577703"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
|
@ -246,8 +246,8 @@ def external_files():
|
|||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_dense_tflite",
|
||||
sha256 = "be9323068461b1cbf412692ee916be30dcb1a5fb59a9ee875d470bc340d9e869",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/dense.tflite?generation=1678218351373709"],
|
||||
sha256 = "6795e7c3a263f44e97be048a5e1166e0921b453bfbaf037f4f69ac5c059ee945",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/dense.tflite?generation=1683252907920466"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
|
|
Loading…
Reference in New Issue
Block a user