Image segmenter outputs both confidence masks and category mask optionally.
PiperOrigin-RevId: 522227345
This commit is contained in:
parent
7fe87936e5
commit
d5def9e24d
|
@ -32,6 +32,7 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/formats/image.h"
|
#include "mediapipe/framework/formats/image.h"
|
||||||
#include "mediapipe/framework/formats/image_frame_opencv.h"
|
#include "mediapipe/framework/formats/image_frame_opencv.h"
|
||||||
#include "mediapipe/framework/formats/tensor.h"
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
|
#include "mediapipe/framework/port/canonical_errors.h"
|
||||||
#include "mediapipe/framework/port/opencv_core_inc.h"
|
#include "mediapipe/framework/port/opencv_core_inc.h"
|
||||||
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
|
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
|
||||||
#include "mediapipe/framework/port/status_macros.h"
|
#include "mediapipe/framework/port/status_macros.h"
|
||||||
|
@ -210,8 +211,9 @@ std::vector<Image> ProcessForConfidenceMaskCpu(const Shape& input_shape,
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Converts Tensors from a vector of Tensor to Segmentation masks. The
|
// Converts Tensors from a vector of Tensor to Segmentation masks. The
|
||||||
// calculator always output confidence masks, and an optional category mask if
|
// calculator can output optional confidence masks if CONFIDENCE_MASK is
|
||||||
// CATEGORY_MASK is connected.
|
// connected, and an optional category mask if CATEGORY_MASK is connected. At
|
||||||
|
// least one of CONFIDENCE_MASK and CATEGORY_MASK must be connected.
|
||||||
//
|
//
|
||||||
// Performs optional resizing to OUTPUT_SIZE dimension if provided,
|
// Performs optional resizing to OUTPUT_SIZE dimension if provided,
|
||||||
// otherwise the segmented masks are the same size as the input tensor.
|
// otherwise the segmented masks are the same size as the input tensor.
|
||||||
|
@ -296,6 +298,13 @@ absl::Status TensorsToSegmentationCalculator::Open(
|
||||||
SegmenterOptions::UNSPECIFIED)
|
SegmenterOptions::UNSPECIFIED)
|
||||||
<< "Must specify output_type as one of "
|
<< "Must specify output_type as one of "
|
||||||
"[CONFIDENCE_MASK|CATEGORY_MASK].";
|
"[CONFIDENCE_MASK|CATEGORY_MASK].";
|
||||||
|
} else {
|
||||||
|
if (!cc->Outputs().HasTag("CONFIDENCE_MASK") &&
|
||||||
|
!cc->Outputs().HasTag("CATEGORY_MASK")) {
|
||||||
|
return absl::InvalidArgumentError(
|
||||||
|
"At least one of CONFIDENCE_MASK and CATEGORY_MASK must be "
|
||||||
|
"connected.");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
#ifdef __EMSCRIPTEN__
|
#ifdef __EMSCRIPTEN__
|
||||||
MP_RETURN_IF_ERROR(postprocessor_.Initialize(cc, options_));
|
MP_RETURN_IF_ERROR(postprocessor_.Initialize(cc, options_));
|
||||||
|
@ -366,8 +375,9 @@ absl::Status TensorsToSegmentationCalculator::Process(
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Image> confidence_masks =
|
if (cc->Outputs().HasTag("CONFIDENCE_MASK")) {
|
||||||
ProcessForConfidenceMaskCpu(input_shape,
|
std::vector<Image> confidence_masks = ProcessForConfidenceMaskCpu(
|
||||||
|
input_shape,
|
||||||
{/* height= */ output_height,
|
{/* height= */ output_height,
|
||||||
/* width= */ output_width,
|
/* width= */ output_width,
|
||||||
/* channels= */ input_shape.channels},
|
/* channels= */ input_shape.channels},
|
||||||
|
@ -375,6 +385,7 @@ absl::Status TensorsToSegmentationCalculator::Process(
|
||||||
for (int i = 0; i < confidence_masks.size(); ++i) {
|
for (int i = 0; i < confidence_masks.size(); ++i) {
|
||||||
kConfidenceMaskOut(cc)[i].Send(std::move(confidence_masks[i]));
|
kConfidenceMaskOut(cc)[i].Send(std::move(confidence_masks[i]));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
if (cc->Outputs().HasTag("CATEGORY_MASK")) {
|
if (cc->Outputs().HasTag("CATEGORY_MASK")) {
|
||||||
kCategoryMaskOut(cc).Send(ProcessForCategoryMaskCpu(
|
kCategoryMaskOut(cc).Send(ProcessForCategoryMaskCpu(
|
||||||
input_shape,
|
input_shape,
|
||||||
|
|
|
@ -60,15 +60,19 @@ using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision::
|
||||||
// "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph".
|
// "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph".
|
||||||
CalculatorGraphConfig CreateGraphConfig(
|
CalculatorGraphConfig CreateGraphConfig(
|
||||||
std::unique_ptr<ImageSegmenterGraphOptionsProto> options,
|
std::unique_ptr<ImageSegmenterGraphOptionsProto> options,
|
||||||
bool output_category_mask, bool enable_flow_limiting) {
|
bool output_confidence_masks, bool output_category_mask,
|
||||||
|
bool enable_flow_limiting) {
|
||||||
api2::builder::Graph graph;
|
api2::builder::Graph graph;
|
||||||
auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
|
auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
|
||||||
task_subgraph.GetOptions<ImageSegmenterGraphOptionsProto>().Swap(
|
task_subgraph.GetOptions<ImageSegmenterGraphOptionsProto>().Swap(
|
||||||
options.get());
|
options.get());
|
||||||
graph.In(kImageTag).SetName(kImageInStreamName);
|
graph.In(kImageTag).SetName(kImageInStreamName);
|
||||||
graph.In(kNormRectTag).SetName(kNormRectStreamName);
|
graph.In(kNormRectTag).SetName(kNormRectStreamName);
|
||||||
task_subgraph.Out(kConfidenceMasksTag).SetName(kConfidenceMasksStreamName) >>
|
if (output_confidence_masks) {
|
||||||
|
task_subgraph.Out(kConfidenceMasksTag)
|
||||||
|
.SetName(kConfidenceMasksStreamName) >>
|
||||||
graph.Out(kConfidenceMasksTag);
|
graph.Out(kConfidenceMasksTag);
|
||||||
|
}
|
||||||
if (output_category_mask) {
|
if (output_category_mask) {
|
||||||
task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
|
task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
|
||||||
graph.Out(kCategoryMaskTag);
|
graph.Out(kCategoryMaskTag);
|
||||||
|
@ -135,11 +139,17 @@ absl::StatusOr<std::vector<std::string>> GetLabelsFromGraphConfig(
|
||||||
|
|
||||||
absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
|
absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
|
||||||
std::unique_ptr<ImageSegmenterOptions> options) {
|
std::unique_ptr<ImageSegmenterOptions> options) {
|
||||||
|
if (!options->output_confidence_masks && !options->output_category_mask) {
|
||||||
|
return absl::InvalidArgumentError(
|
||||||
|
"At least one of `output_confidence_masks` and `output_category_mask` "
|
||||||
|
"must be set.");
|
||||||
|
}
|
||||||
auto options_proto = ConvertImageSegmenterOptionsToProto(options.get());
|
auto options_proto = ConvertImageSegmenterOptionsToProto(options.get());
|
||||||
tasks::core::PacketsCallback packets_callback = nullptr;
|
tasks::core::PacketsCallback packets_callback = nullptr;
|
||||||
if (options->result_callback) {
|
if (options->result_callback) {
|
||||||
auto result_callback = options->result_callback;
|
auto result_callback = options->result_callback;
|
||||||
bool output_category_mask = options->output_category_mask;
|
bool output_category_mask = options->output_category_mask;
|
||||||
|
bool output_confidence_masks = options->output_confidence_masks;
|
||||||
packets_callback =
|
packets_callback =
|
||||||
[=](absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
|
[=](absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
|
||||||
if (!status_or_packets.ok()) {
|
if (!status_or_packets.ok()) {
|
||||||
|
@ -151,8 +161,12 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
|
||||||
if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) {
|
if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
Packet confidence_masks =
|
std::optional<std::vector<Image>> confidence_masks;
|
||||||
status_or_packets.value()[kConfidenceMasksStreamName];
|
if (output_confidence_masks) {
|
||||||
|
confidence_masks =
|
||||||
|
status_or_packets.value()[kConfidenceMasksStreamName]
|
||||||
|
.Get<std::vector<Image>>();
|
||||||
|
}
|
||||||
std::optional<Image> category_mask;
|
std::optional<Image> category_mask;
|
||||||
if (output_category_mask) {
|
if (output_category_mask) {
|
||||||
category_mask =
|
category_mask =
|
||||||
|
@ -160,23 +174,24 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
|
||||||
}
|
}
|
||||||
Packet image_packet = status_or_packets.value()[kImageOutStreamName];
|
Packet image_packet = status_or_packets.value()[kImageOutStreamName];
|
||||||
result_callback(
|
result_callback(
|
||||||
{{confidence_masks.Get<std::vector<Image>>(), category_mask}},
|
{{confidence_masks, category_mask}}, image_packet.Get<Image>(),
|
||||||
image_packet.Get<Image>(),
|
image_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond);
|
||||||
confidence_masks.Timestamp().Value() /
|
|
||||||
kMicroSecondsPerMilliSecond);
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
auto image_segmenter =
|
auto image_segmenter =
|
||||||
core::VisionTaskApiFactory::Create<ImageSegmenter,
|
core::VisionTaskApiFactory::Create<ImageSegmenter,
|
||||||
ImageSegmenterGraphOptionsProto>(
|
ImageSegmenterGraphOptionsProto>(
|
||||||
CreateGraphConfig(
|
CreateGraphConfig(
|
||||||
std::move(options_proto), options->output_category_mask,
|
std::move(options_proto), options->output_confidence_masks,
|
||||||
|
options->output_category_mask,
|
||||||
options->running_mode == core::RunningMode::LIVE_STREAM),
|
options->running_mode == core::RunningMode::LIVE_STREAM),
|
||||||
std::move(options->base_options.op_resolver), options->running_mode,
|
std::move(options->base_options.op_resolver), options->running_mode,
|
||||||
std::move(packets_callback));
|
std::move(packets_callback));
|
||||||
if (!image_segmenter.ok()) {
|
if (!image_segmenter.ok()) {
|
||||||
return image_segmenter.status();
|
return image_segmenter.status();
|
||||||
}
|
}
|
||||||
|
image_segmenter.value()->output_confidence_masks_ =
|
||||||
|
options->output_confidence_masks;
|
||||||
image_segmenter.value()->output_category_mask_ =
|
image_segmenter.value()->output_category_mask_ =
|
||||||
options->output_category_mask;
|
options->output_category_mask;
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(
|
||||||
|
@ -203,8 +218,11 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::Segment(
|
||||||
{{kImageInStreamName, mediapipe::MakePacket<Image>(std::move(image))},
|
{{kImageInStreamName, mediapipe::MakePacket<Image>(std::move(image))},
|
||||||
{kNormRectStreamName,
|
{kNormRectStreamName,
|
||||||
MakePacket<NormalizedRect>(std::move(norm_rect))}}));
|
MakePacket<NormalizedRect>(std::move(norm_rect))}}));
|
||||||
std::vector<Image> confidence_masks =
|
std::optional<std::vector<Image>> confidence_masks;
|
||||||
|
if (output_confidence_masks_) {
|
||||||
|
confidence_masks =
|
||||||
output_packets[kConfidenceMasksStreamName].Get<std::vector<Image>>();
|
output_packets[kConfidenceMasksStreamName].Get<std::vector<Image>>();
|
||||||
|
}
|
||||||
std::optional<Image> category_mask;
|
std::optional<Image> category_mask;
|
||||||
if (output_category_mask_) {
|
if (output_category_mask_) {
|
||||||
category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
|
category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
|
||||||
|
@ -233,8 +251,11 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
|
||||||
{kNormRectStreamName,
|
{kNormRectStreamName,
|
||||||
MakePacket<NormalizedRect>(std::move(norm_rect))
|
MakePacket<NormalizedRect>(std::move(norm_rect))
|
||||||
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
|
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
|
||||||
std::vector<Image> confidence_masks =
|
std::optional<std::vector<Image>> confidence_masks;
|
||||||
|
if (output_confidence_masks_) {
|
||||||
|
confidence_masks =
|
||||||
output_packets[kConfidenceMasksStreamName].Get<std::vector<Image>>();
|
output_packets[kConfidenceMasksStreamName].Get<std::vector<Image>>();
|
||||||
|
}
|
||||||
std::optional<Image> category_mask;
|
std::optional<Image> category_mask;
|
||||||
if (output_category_mask_) {
|
if (output_category_mask_) {
|
||||||
category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
|
category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
|
||||||
|
|
|
@ -53,6 +53,9 @@ struct ImageSegmenterOptions {
|
||||||
// Metadata, if any. Defaults to English.
|
// Metadata, if any. Defaults to English.
|
||||||
std::string display_names_locale = "en";
|
std::string display_names_locale = "en";
|
||||||
|
|
||||||
|
// Whether to output confidence masks.
|
||||||
|
bool output_confidence_masks = true;
|
||||||
|
|
||||||
// Whether to output category mask.
|
// Whether to output category mask.
|
||||||
bool output_category_mask = false;
|
bool output_category_mask = false;
|
||||||
|
|
||||||
|
@ -77,8 +80,10 @@ struct ImageSegmenterOptions {
|
||||||
// - if type is kTfLiteFloat32, NormalizationOptions are required to be
|
// - if type is kTfLiteFloat32, NormalizationOptions are required to be
|
||||||
// attached to the metadata for input normalization.
|
// attached to the metadata for input normalization.
|
||||||
// Output ImageSegmenterResult:
|
// Output ImageSegmenterResult:
|
||||||
// Provides confidence masks and an optional category mask if
|
// Provides optional confidence masks if `output_confidence_masks` is set
|
||||||
// `output_category_mask` is set true.
|
// true, and an optional category mask if `output_category_mask` is set
|
||||||
|
// true. At least one of `output_confidence_masks` and `output_category_mask`
|
||||||
|
// must be set to true.
|
||||||
// An example of such model can be found at:
|
// An example of such model can be found at:
|
||||||
// https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2
|
// https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2
|
||||||
class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
|
class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
|
||||||
|
@ -167,6 +172,7 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<std::string> labels_;
|
std::vector<std::string> labels_;
|
||||||
|
bool output_confidence_masks_;
|
||||||
bool output_category_mask_;
|
bool output_category_mask_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -326,8 +326,10 @@ absl::StatusOr<ImageAndTensorsOnDevice> ConvertImageToTensors(
|
||||||
}
|
}
|
||||||
|
|
||||||
// A "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph" performs
|
// A "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph" performs
|
||||||
// semantic segmentation. The graph always output confidence masks, and an
|
// semantic segmentation. The graph can output optional confidence masks if
|
||||||
// optional category mask if CATEGORY_MASK is connected.
|
// CONFIDENCE_MASKS is connected, and an optional category mask if CATEGORY_MASK
|
||||||
|
// is connected. At least one of CONFIDENCE_MASK, CONFIDENCE_MASKS and
|
||||||
|
// CATEGORY_MASK must be connected.
|
||||||
//
|
//
|
||||||
// Two kinds of outputs for confidence mask are provided: CONFIDENCE_MASK and
|
// Two kinds of outputs for confidence mask are provided: CONFIDENCE_MASK and
|
||||||
// CONFIDENCE_MASKS. Users can retrieve segmented mask of only particular
|
// CONFIDENCE_MASKS. Users can retrieve segmented mask of only particular
|
||||||
|
@ -347,7 +349,7 @@ absl::StatusOr<ImageAndTensorsOnDevice> ConvertImageToTensors(
|
||||||
// CONFIDENCE_MASK - mediapipe::Image @Multiple
|
// CONFIDENCE_MASK - mediapipe::Image @Multiple
|
||||||
// Confidence masks for individual categories. The confidence mask of a single
|
// Confidence masks for individual categories. The confidence mask of a single
|
||||||
// category can be accessed by an index-based output stream.
|
// category can be accessed by an index-based output stream.
|
||||||
// CONFIDENCE_MASKS - std::vector<mediapipe::Image>
|
// CONFIDENCE_MASKS - std::vector<mediapipe::Image> @Optional
|
||||||
// The output confidence masks grouped in a vector.
|
// The output confidence masks grouped in a vector.
|
||||||
// CATEGORY_MASK - mediapipe::Image @Optional
|
// CATEGORY_MASK - mediapipe::Image @Optional
|
||||||
// Optional Category mask.
|
// Optional Category mask.
|
||||||
|
@ -356,7 +358,7 @@ absl::StatusOr<ImageAndTensorsOnDevice> ConvertImageToTensors(
|
||||||
//
|
//
|
||||||
// Example:
|
// Example:
|
||||||
// node {
|
// node {
|
||||||
// calculator: "mediapipe.tasks.vision.ImageSegmenterGraph"
|
// calculator: "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph"
|
||||||
// input_stream: "IMAGE:image"
|
// input_stream: "IMAGE:image"
|
||||||
// output_stream: "SEGMENTATION:segmented_masks"
|
// output_stream: "SEGMENTATION:segmented_masks"
|
||||||
// options {
|
// options {
|
||||||
|
@ -382,17 +384,20 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
|
||||||
CreateModelResources<ImageSegmenterGraphOptions>(sc));
|
CreateModelResources<ImageSegmenterGraphOptions>(sc));
|
||||||
Graph graph;
|
Graph graph;
|
||||||
const auto& options = sc->Options<ImageSegmenterGraphOptions>();
|
const auto& options = sc->Options<ImageSegmenterGraphOptions>();
|
||||||
|
// TODO: remove deprecated output type support.
|
||||||
|
if (!options.segmenter_options().has_output_type()) {
|
||||||
|
MP_RETURN_IF_ERROR(SanityCheck(sc));
|
||||||
|
}
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(
|
||||||
auto output_streams,
|
auto output_streams,
|
||||||
BuildSegmentationTask(
|
BuildSegmentationTask(
|
||||||
options, *model_resources, graph[Input<Image>(kImageTag)],
|
options, *model_resources, graph[Input<Image>(kImageTag)],
|
||||||
graph[Input<NormalizedRect>::Optional(kNormRectTag)],
|
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
|
||||||
HasOutput(sc->OriginalNode(), kCategoryMaskTag), graph));
|
|
||||||
|
|
||||||
auto& merge_images_to_vector =
|
|
||||||
graph.AddNode("MergeImagesToVectorCalculator");
|
|
||||||
// TODO: remove deprecated output type support.
|
// TODO: remove deprecated output type support.
|
||||||
if (options.segmenter_options().has_output_type()) {
|
if (options.segmenter_options().has_output_type()) {
|
||||||
|
auto& merge_images_to_vector =
|
||||||
|
graph.AddNode("MergeImagesToVectorCalculator");
|
||||||
for (int i = 0; i < output_streams.segmented_masks->size(); ++i) {
|
for (int i = 0; i < output_streams.segmented_masks->size(); ++i) {
|
||||||
output_streams.segmented_masks->at(i) >>
|
output_streams.segmented_masks->at(i) >>
|
||||||
merge_images_to_vector[Input<Image>::Multiple("")][i];
|
merge_images_to_vector[Input<Image>::Multiple("")][i];
|
||||||
|
@ -402,6 +407,9 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
|
||||||
merge_images_to_vector.Out("") >>
|
merge_images_to_vector.Out("") >>
|
||||||
graph[Output<std::vector<Image>>(kGroupedSegmentationTag)];
|
graph[Output<std::vector<Image>>(kGroupedSegmentationTag)];
|
||||||
} else {
|
} else {
|
||||||
|
if (output_streams.confidence_masks) {
|
||||||
|
auto& merge_images_to_vector =
|
||||||
|
graph.AddNode("MergeImagesToVectorCalculator");
|
||||||
for (int i = 0; i < output_streams.confidence_masks->size(); ++i) {
|
for (int i = 0; i < output_streams.confidence_masks->size(); ++i) {
|
||||||
output_streams.confidence_masks->at(i) >>
|
output_streams.confidence_masks->at(i) >>
|
||||||
merge_images_to_vector[Input<Image>::Multiple("")][i];
|
merge_images_to_vector[Input<Image>::Multiple("")][i];
|
||||||
|
@ -409,7 +417,8 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
|
||||||
graph[Output<Image>::Multiple(kConfidenceMaskTag)][i];
|
graph[Output<Image>::Multiple(kConfidenceMaskTag)][i];
|
||||||
}
|
}
|
||||||
merge_images_to_vector.Out("") >>
|
merge_images_to_vector.Out("") >>
|
||||||
graph[Output<std::vector<Image>>(kConfidenceMasksTag)];
|
graph[Output<std::vector<Image>>::Optional(kConfidenceMasksTag)];
|
||||||
|
}
|
||||||
if (output_streams.category_mask) {
|
if (output_streams.category_mask) {
|
||||||
*output_streams.category_mask >> graph[Output<Image>(kCategoryMaskTag)];
|
*output_streams.category_mask >> graph[Output<Image>(kCategoryMaskTag)];
|
||||||
}
|
}
|
||||||
|
@ -419,6 +428,19 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
absl::Status SanityCheck(mediapipe::SubgraphContext* sc) {
|
||||||
|
const auto& node = sc->OriginalNode();
|
||||||
|
output_confidence_masks_ = HasOutput(node, kConfidenceMaskTag) ||
|
||||||
|
HasOutput(node, kConfidenceMasksTag);
|
||||||
|
output_category_mask_ = HasOutput(node, kCategoryMaskTag);
|
||||||
|
if (!output_confidence_masks_ && !output_category_mask_) {
|
||||||
|
return absl::InvalidArgumentError(
|
||||||
|
"At least one of CONFIDENCE_MASK, CONFIDENCE_MASKS and CATEGORY_MASK "
|
||||||
|
"must be connected.");
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
// Adds a mediapipe image segmentation task pipeline graph into the provided
|
// Adds a mediapipe image segmentation task pipeline graph into the provided
|
||||||
// builder::Graph instance. The segmentation pipeline takes images
|
// builder::Graph instance. The segmentation pipeline takes images
|
||||||
// (mediapipe::Image) as the input and returns segmented image mask as output.
|
// (mediapipe::Image) as the input and returns segmented image mask as output.
|
||||||
|
@ -431,8 +453,7 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
|
||||||
absl::StatusOr<ImageSegmenterOutputs> BuildSegmentationTask(
|
absl::StatusOr<ImageSegmenterOutputs> BuildSegmentationTask(
|
||||||
const ImageSegmenterGraphOptions& task_options,
|
const ImageSegmenterGraphOptions& task_options,
|
||||||
const core::ModelResources& model_resources, Source<Image> image_in,
|
const core::ModelResources& model_resources, Source<Image> image_in,
|
||||||
Source<NormalizedRect> norm_rect_in, bool output_category_mask,
|
Source<NormalizedRect> norm_rect_in, Graph& graph) {
|
||||||
Graph& graph) {
|
|
||||||
MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));
|
MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));
|
||||||
|
|
||||||
// Adds preprocessing calculators and connects them to the graph input image
|
// Adds preprocessing calculators and connects them to the graph input image
|
||||||
|
@ -485,26 +506,32 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
|
||||||
/*category_mask=*/std::nullopt,
|
/*category_mask=*/std::nullopt,
|
||||||
/*image=*/image_and_tensors.image};
|
/*image=*/image_and_tensors.image};
|
||||||
} else {
|
} else {
|
||||||
|
std::optional<std::vector<Source<Image>>> confidence_masks;
|
||||||
|
if (output_confidence_masks_) {
|
||||||
ASSIGN_OR_RETURN(const tflite::Tensor* output_tensor,
|
ASSIGN_OR_RETURN(const tflite::Tensor* output_tensor,
|
||||||
GetOutputTensor(model_resources));
|
GetOutputTensor(model_resources));
|
||||||
int segmentation_streams_num = *output_tensor->shape()->rbegin();
|
int segmentation_streams_num = *output_tensor->shape()->rbegin();
|
||||||
std::vector<Source<Image>> confidence_masks;
|
confidence_masks = std::vector<Source<Image>>();
|
||||||
confidence_masks.reserve(segmentation_streams_num);
|
confidence_masks->reserve(segmentation_streams_num);
|
||||||
for (int i = 0; i < segmentation_streams_num; ++i) {
|
for (int i = 0; i < segmentation_streams_num; ++i) {
|
||||||
confidence_masks.push_back(Source<Image>(
|
confidence_masks->push_back(Source<Image>(
|
||||||
tensor_to_images[Output<Image>::Multiple(kConfidenceMaskTag)][i]));
|
tensor_to_images[Output<Image>::Multiple(kConfidenceMaskTag)]
|
||||||
|
[i]));
|
||||||
}
|
}
|
||||||
return ImageSegmenterOutputs{
|
}
|
||||||
/*segmented_masks=*/std::nullopt,
|
std::optional<Source<Image>> category_mask;
|
||||||
|
if (output_category_mask_) {
|
||||||
|
category_mask = tensor_to_images[Output<Image>(kCategoryMaskTag)];
|
||||||
|
}
|
||||||
|
return ImageSegmenterOutputs{/*segmented_masks=*/std::nullopt,
|
||||||
/*confidence_masks=*/confidence_masks,
|
/*confidence_masks=*/confidence_masks,
|
||||||
/*category_mask=*/
|
/*category_mask=*/category_mask,
|
||||||
output_category_mask
|
|
||||||
? std::make_optional(
|
|
||||||
tensor_to_images[Output<Image>(kCategoryMaskTag)])
|
|
||||||
: std::nullopt,
|
|
||||||
/*image=*/image_and_tensors.image};
|
/*image=*/image_and_tensors.image};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool output_confidence_masks_ = false;
|
||||||
|
bool output_category_mask_ = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
REGISTER_MEDIAPIPE_GRAPH(
|
REGISTER_MEDIAPIPE_GRAPH(
|
||||||
|
|
|
@ -29,7 +29,7 @@ namespace image_segmenter {
|
||||||
struct ImageSegmenterResult {
|
struct ImageSegmenterResult {
|
||||||
// Multiple masks of float image in VEC32F1 format where, for each mask, each
|
// Multiple masks of float image in VEC32F1 format where, for each mask, each
|
||||||
// pixel represents the prediction confidence, usually in the [0, 1] range.
|
// pixel represents the prediction confidence, usually in the [0, 1] range.
|
||||||
std::vector<Image> confidence_masks;
|
std::optional<std::vector<Image>> confidence_masks;
|
||||||
// A category mask of uint8 image in GRAY8 format where each pixel represents
|
// A category mask of uint8 image in GRAY8 format where each pixel represents
|
||||||
// the class which the pixel in the original image was predicted to belong to.
|
// the class which the pixel in the original image was predicted to belong to.
|
||||||
std::optional<Image> category_mask;
|
std::optional<Image> category_mask;
|
||||||
|
|
|
@ -278,6 +278,7 @@ TEST_F(ImageModeTest, SucceedsWithCategoryMask) {
|
||||||
auto options = std::make_unique<ImageSegmenterOptions>();
|
auto options = std::make_unique<ImageSegmenterOptions>();
|
||||||
options->base_options.model_asset_path =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
|
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
|
||||||
|
options->output_confidence_masks = false;
|
||||||
options->output_category_mask = true;
|
options->output_category_mask = true;
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
||||||
ImageSegmenter::Create(std::move(options)));
|
ImageSegmenter::Create(std::move(options)));
|
||||||
|
@ -306,7 +307,7 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
||||||
ImageSegmenter::Create(std::move(options)));
|
ImageSegmenter::Create(std::move(options)));
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
|
MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
|
||||||
EXPECT_EQ(result.confidence_masks.size(), 21);
|
EXPECT_EQ(result.confidence_masks->size(), 21);
|
||||||
|
|
||||||
cv::Mat expected_mask = cv::imread(
|
cv::Mat expected_mask = cv::imread(
|
||||||
JoinPath("./", kTestDataDirectory, "cat_mask.jpg"), cv::IMREAD_GRAYSCALE);
|
JoinPath("./", kTestDataDirectory, "cat_mask.jpg"), cv::IMREAD_GRAYSCALE);
|
||||||
|
@ -315,7 +316,7 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {
|
||||||
|
|
||||||
// Cat category index 8.
|
// Cat category index 8.
|
||||||
cv::Mat cat_mask = mediapipe::formats::MatView(
|
cv::Mat cat_mask = mediapipe::formats::MatView(
|
||||||
result.confidence_masks[8].GetImageFrameSharedPtr().get());
|
result.confidence_masks->at(8).GetImageFrameSharedPtr().get());
|
||||||
EXPECT_THAT(cat_mask,
|
EXPECT_THAT(cat_mask,
|
||||||
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
|
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
|
||||||
}
|
}
|
||||||
|
@ -336,7 +337,7 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) {
|
||||||
image_processing_options.rotation_degrees = -90;
|
image_processing_options.rotation_degrees = -90;
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto result,
|
MP_ASSERT_OK_AND_ASSIGN(auto result,
|
||||||
segmenter->Segment(image, image_processing_options));
|
segmenter->Segment(image, image_processing_options));
|
||||||
EXPECT_EQ(result.confidence_masks.size(), 21);
|
EXPECT_EQ(result.confidence_masks->size(), 21);
|
||||||
|
|
||||||
cv::Mat expected_mask =
|
cv::Mat expected_mask =
|
||||||
cv::imread(JoinPath("./", kTestDataDirectory, "cat_rotated_mask.jpg"),
|
cv::imread(JoinPath("./", kTestDataDirectory, "cat_rotated_mask.jpg"),
|
||||||
|
@ -346,7 +347,7 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) {
|
||||||
|
|
||||||
// Cat category index 8.
|
// Cat category index 8.
|
||||||
cv::Mat cat_mask = mediapipe::formats::MatView(
|
cv::Mat cat_mask = mediapipe::formats::MatView(
|
||||||
result.confidence_masks[8].GetImageFrameSharedPtr().get());
|
result.confidence_masks->at(8).GetImageFrameSharedPtr().get());
|
||||||
EXPECT_THAT(cat_mask,
|
EXPECT_THAT(cat_mask,
|
||||||
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
|
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
|
||||||
}
|
}
|
||||||
|
@ -384,7 +385,7 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
||||||
ImageSegmenter::Create(std::move(options)));
|
ImageSegmenter::Create(std::move(options)));
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
|
MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
|
||||||
EXPECT_EQ(result.confidence_masks.size(), 2);
|
EXPECT_EQ(result.confidence_masks->size(), 2);
|
||||||
|
|
||||||
cv::Mat expected_mask =
|
cv::Mat expected_mask =
|
||||||
cv::imread(JoinPath("./", kTestDataDirectory,
|
cv::imread(JoinPath("./", kTestDataDirectory,
|
||||||
|
@ -395,7 +396,7 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) {
|
||||||
|
|
||||||
// Selfie category index 1.
|
// Selfie category index 1.
|
||||||
cv::Mat selfie_mask = mediapipe::formats::MatView(
|
cv::Mat selfie_mask = mediapipe::formats::MatView(
|
||||||
result.confidence_masks[1].GetImageFrameSharedPtr().get());
|
result.confidence_masks->at(1).GetImageFrameSharedPtr().get());
|
||||||
EXPECT_THAT(selfie_mask,
|
EXPECT_THAT(selfie_mask,
|
||||||
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
|
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
|
||||||
}
|
}
|
||||||
|
@ -409,7 +410,7 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
||||||
ImageSegmenter::Create(std::move(options)));
|
ImageSegmenter::Create(std::move(options)));
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
|
MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
|
||||||
EXPECT_EQ(result.confidence_masks.size(), 1);
|
EXPECT_EQ(result.confidence_masks->size(), 1);
|
||||||
|
|
||||||
cv::Mat expected_mask =
|
cv::Mat expected_mask =
|
||||||
cv::imread(JoinPath("./", kTestDataDirectory,
|
cv::imread(JoinPath("./", kTestDataDirectory,
|
||||||
|
@ -419,7 +420,7 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) {
|
||||||
expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);
|
expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);
|
||||||
|
|
||||||
cv::Mat selfie_mask = mediapipe::formats::MatView(
|
cv::Mat selfie_mask = mediapipe::formats::MatView(
|
||||||
result.confidence_masks[0].GetImageFrameSharedPtr().get());
|
result.confidence_masks->at(0).GetImageFrameSharedPtr().get());
|
||||||
EXPECT_THAT(selfie_mask,
|
EXPECT_THAT(selfie_mask,
|
||||||
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
|
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
|
||||||
}
|
}
|
||||||
|
@ -434,7 +435,7 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
||||||
ImageSegmenter::Create(std::move(options)));
|
ImageSegmenter::Create(std::move(options)));
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
|
MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
|
||||||
EXPECT_EQ(result.confidence_masks.size(), 1);
|
EXPECT_EQ(result.confidence_masks->size(), 1);
|
||||||
MP_ASSERT_OK(segmenter->Close());
|
MP_ASSERT_OK(segmenter->Close());
|
||||||
|
|
||||||
cv::Mat expected_mask = cv::imread(
|
cv::Mat expected_mask = cv::imread(
|
||||||
|
@ -445,7 +446,7 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) {
|
||||||
expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);
|
expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);
|
||||||
|
|
||||||
cv::Mat selfie_mask = mediapipe::formats::MatView(
|
cv::Mat selfie_mask = mediapipe::formats::MatView(
|
||||||
result.confidence_masks[0].GetImageFrameSharedPtr().get());
|
result.confidence_masks->at(0).GetImageFrameSharedPtr().get());
|
||||||
EXPECT_THAT(selfie_mask,
|
EXPECT_THAT(selfie_mask,
|
||||||
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
|
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
|
||||||
}
|
}
|
||||||
|
@ -506,10 +507,10 @@ TEST_F(ImageModeTest, SucceedsHairSegmentation) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
||||||
ImageSegmenter::Create(std::move(options)));
|
ImageSegmenter::Create(std::move(options)));
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
|
MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
|
||||||
EXPECT_EQ(result.confidence_masks.size(), 2);
|
EXPECT_EQ(result.confidence_masks->size(), 2);
|
||||||
|
|
||||||
cv::Mat hair_mask = mediapipe::formats::MatView(
|
cv::Mat hair_mask = mediapipe::formats::MatView(
|
||||||
result.confidence_masks[1].GetImageFrameSharedPtr().get());
|
result.confidence_masks->at(1).GetImageFrameSharedPtr().get());
|
||||||
MP_ASSERT_OK(segmenter->Close());
|
MP_ASSERT_OK(segmenter->Close());
|
||||||
cv::Mat expected_mask = cv::imread(
|
cv::Mat expected_mask = cv::imread(
|
||||||
JoinPath("./", kTestDataDirectory, "portrait_hair_expected_mask.jpg"),
|
JoinPath("./", kTestDataDirectory, "portrait_hair_expected_mask.jpg"),
|
||||||
|
|
Loading…
Reference in New Issue
Block a user