update ImageSegmenterGraph to always output confidence mask and optionally output category mask
PiperOrigin-RevId: 521679910

commit 367ccbfdf3
parent c31a4681e5
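For orientation, a minimal caller-side sketch of how the updated C++ task API behaves after this change: confidence masks are always returned, and the category mask only when `output_category_mask` is set. This sketch is not part of the commit; the model path, the wrapper function, and the error handling are illustrative assumptions.

#include <memory>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h"

namespace {
using ::mediapipe::tasks::vision::image_segmenter::ImageSegmenter;
using ::mediapipe::tasks::vision::image_segmenter::ImageSegmenterOptions;
using ::mediapipe::tasks::vision::image_segmenter::ImageSegmenterResult;

absl::Status RunSegmentation(mediapipe::Image image) {
  auto options = std::make_unique<ImageSegmenterOptions>();
  options->base_options.model_asset_path = "segmentation_model.tflite";  // assumed path
  options->output_category_mask = true;  // opt in to the optional CATEGORY_MASK output
  absl::StatusOr<std::unique_ptr<ImageSegmenter>> segmenter =
      ImageSegmenter::Create(std::move(options));
  if (!segmenter.ok()) return segmenter.status();
  absl::StatusOr<ImageSegmenterResult> result = (*segmenter)->Segment(image);
  if (!result.ok()) return result.status();
  // Confidence masks are always populated, one float mask per category.
  for (const mediapipe::Image& mask : result->confidence_masks) {
    (void)mask;  // ... consume per-category confidence mask ...
  }
  // The category mask is present only because output_category_mask was set.
  if (result->category_mask.has_value()) {
    // ... consume *result->category_mask ...
  }
  return (*segmenter)->Close();
}
}  // namespace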
@@ -16,6 +16,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"])

 licenses(["notice"])

+cc_library(
+    name = "image_segmenter_result",
+    hdrs = ["image_segmenter_result.h"],
+    visibility = ["//visibility:public"],
+    deps = ["//mediapipe/framework/formats:image"],
+)
+
 # Docs for Mediapipe Tasks Image Segmenter
 # https://developers.google.com/mediapipe/solutions/vision/image_segmenter
 cc_library(
@@ -25,6 +32,7 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":image_segmenter_graph",
+        ":image_segmenter_result",
        "//mediapipe/framework:calculator_cc_proto",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/formats:image",
@@ -82,6 +90,7 @@ cc_library(
        "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
        "//mediapipe/tasks/metadata:image_segmenter_metadata_schema_cc",
        "//mediapipe/tasks/metadata:metadata_schema_cc",
+        "//mediapipe/util:graph_builder_utils",
        "//mediapipe/util:label_map_cc_proto",
        "//mediapipe/util:label_map_util",
        "@com_google_absl//absl/status",
@@ -80,10 +80,10 @@ void Sigmoid(absl::Span<const float> values,
                  [](float value) { return 1. / (1 + std::exp(-value)); });
 }

-std::vector<Image> ProcessForCategoryMaskCpu(const Shape& input_shape,
+Image ProcessForCategoryMaskCpu(const Shape& input_shape,
                                 const Shape& output_shape,
                                 const SegmenterOptions& options,
                                 const float* tensors_buffer) {
   cv::Mat resized_tensors_mat;
   cv::Mat tensors_mat_view(
       input_shape.height, input_shape.width, CV_32FC(input_shape.channels),
@@ -135,7 +135,7 @@ std::vector<Image> ProcessForCategoryMaskCpu(const Shape& input_shape,
           pixel = maximum_category_idx;
         }
       });
-  return {category_mask};
+  return category_mask;
 }

 std::vector<Image> ProcessForConfidenceMaskCpu(const Shape& input_shape,
@@ -209,7 +209,9 @@ std::vector<Image> ProcessForConfidenceMaskCpu(const Shape& input_shape,

 }  // namespace

-// Converts Tensors from a vector of Tensor to Segmentation.
+// Converts Tensors from a vector of Tensor to Segmentation masks. The
+// calculator always output confidence masks, and an optional category mask if
+// CATEGORY_MASK is connected.
 //
 // Performs optional resizing to OUTPUT_SIZE dimension if provided,
 // otherwise the segmented masks is the same size as input tensor.
@@ -221,7 +223,12 @@ std::vector<Image> ProcessForConfidenceMaskCpu(const Shape& input_shape,
 //   the size to resize masks to.
 //
 // Output:
-//   Segmentation: Segmentation proto.
+//   CONFIDENCE_MASK @Multiple: Multiple masks of float image where, for each
+//   mask, each pixel represents the prediction confidence, usually in the [0,
+//   1] range.
+//   CATEGORY_MASK @Optional: A category mask of uint8 image where each pixel
+//   represents the class which the pixel in the original image was predicted to
+//   belong to.
 //
 // Options:
 //   See tensors_to_segmentation_calculator.proto
@@ -231,13 +238,13 @@ std::vector<Image> ProcessForConfidenceMaskCpu(const Shape& input_shape,
 //   calculator: "TensorsToSegmentationCalculator"
 //   input_stream: "TENSORS:tensors"
 //   input_stream: "OUTPUT_SIZE:size"
-//   output_stream: "SEGMENTATION:0:segmentation"
-//   output_stream: "SEGMENTATION:1:segmentation"
+//   output_stream: "CONFIDENCE_MASK:0:confidence_mask"
+//   output_stream: "CONFIDENCE_MASK:1:confidence_mask"
+//   output_stream: "CATEGORY_MASK:category_mask"
 //   options {
 //     [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] {
 //       segmenter_options {
 //         activation: SOFTMAX
-//         output_type: CONFIDENCE_MASK
 //       }
 //     }
 //   }
@@ -248,7 +255,11 @@ class TensorsToSegmentationCalculator : public Node {
   static constexpr Input<std::pair<int, int>>::Optional kOutputSizeIn{
       "OUTPUT_SIZE"};
   static constexpr Output<Image>::Multiple kSegmentationOut{"SEGMENTATION"};
-  MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kOutputSizeIn, kSegmentationOut);
+  static constexpr Output<Image>::Multiple kConfidenceMaskOut{
+      "CONFIDENCE_MASK"};
+  static constexpr Output<Image>::Optional kCategoryMaskOut{"CATEGORY_MASK"};
+  MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kOutputSizeIn, kSegmentationOut,
+                          kConfidenceMaskOut, kCategoryMaskOut);

   static absl::Status UpdateContract(CalculatorContract* cc);

@@ -279,9 +290,13 @@ absl::Status TensorsToSegmentationCalculator::UpdateContract(
 absl::Status TensorsToSegmentationCalculator::Open(
     mediapipe::CalculatorContext* cc) {
   options_ = cc->Options<TensorsToSegmentationCalculatorOptions>();
-  RET_CHECK_NE(options_.segmenter_options().output_type(),
-               SegmenterOptions::UNSPECIFIED)
-      << "Must specify output_type as one of [CONFIDENCE_MASK|CATEGORY_MASK].";
+  // TODO: remove deprecated output type support.
+  if (options_.segmenter_options().has_output_type()) {
+    RET_CHECK_NE(options_.segmenter_options().output_type(),
+                 SegmenterOptions::UNSPECIFIED)
+        << "Must specify output_type as one of "
+           "[CONFIDENCE_MASK|CATEGORY_MASK].";
+  }
 #ifdef __EMSCRIPTEN__
   MP_RETURN_IF_ERROR(postprocessor_.Initialize(cc, options_));
 #endif  // __EMSCRIPTEN__
@@ -309,6 +324,10 @@ absl::Status TensorsToSegmentationCalculator::Process(
   if (cc->Inputs().HasTag("OUTPUT_SIZE")) {
     std::tie(output_width, output_height) = kOutputSizeIn(cc).Get();
   }

+  // Use GPU postprocessing on web when Tensor is there already and has <= 12
+  // categories.
+#ifdef __EMSCRIPTEN__
   Shape output_shape = {
       /* height= */ output_height,
       /* width= */ output_width,
@@ -316,10 +335,6 @@ absl::Status TensorsToSegmentationCalculator::Process(
               SegmenterOptions::CATEGORY_MASK
           ? 1
           : input_shape.channels};

-  // Use GPU postprocessing on web when Tensor is there already and has <= 12
-  // categories.
-#ifdef __EMSCRIPTEN__
   if (input_tensor.ready_as_opengl_texture_2d() && input_shape.channels <= 12) {
     std::vector<std::unique_ptr<Image>> segmented_masks =
         postprocessor_.GetSegmentationResultGpu(input_shape, output_shape,
@@ -332,10 +347,41 @@ absl::Status TensorsToSegmentationCalculator::Process(
 #endif  // __EMSCRIPTEN__

   // Otherwise, use CPU postprocessing.
-  std::vector<Image> segmented_masks = GetSegmentationResultCpu(
-      input_shape, output_shape, input_tensor.GetCpuReadView().buffer<float>());
-  for (int i = 0; i < segmented_masks.size(); ++i) {
-    kSegmentationOut(cc)[i].Send(std::move(segmented_masks[i]));
+  const float* tensors_buffer = input_tensor.GetCpuReadView().buffer<float>();
+
+  // TODO: remove deprecated output type support.
+  if (options_.segmenter_options().has_output_type()) {
+    std::vector<Image> segmented_masks = GetSegmentationResultCpu(
+        input_shape,
+        {/* height= */ output_height,
+         /* width= */ output_width,
+         /* channels= */ options_.segmenter_options().output_type() ==
+                 SegmenterOptions::CATEGORY_MASK
+             ? 1
+             : input_shape.channels},
+        input_tensor.GetCpuReadView().buffer<float>());
+    for (int i = 0; i < segmented_masks.size(); ++i) {
+      kSegmentationOut(cc)[i].Send(std::move(segmented_masks[i]));
+    }
+    return absl::OkStatus();
+  }
+
+  std::vector<Image> confidence_masks =
+      ProcessForConfidenceMaskCpu(input_shape,
+                                  {/* height= */ output_height,
+                                   /* width= */ output_width,
+                                   /* channels= */ input_shape.channels},
+                                  options_.segmenter_options(), tensors_buffer);
+  for (int i = 0; i < confidence_masks.size(); ++i) {
+    kConfidenceMaskOut(cc)[i].Send(std::move(confidence_masks[i]));
+  }
+  if (cc->Outputs().HasTag("CATEGORY_MASK")) {
+    kCategoryMaskOut(cc).Send(ProcessForCategoryMaskCpu(
+        input_shape,
+        {/* height= */ output_height,
+         /* width= */ output_width,
+         /* channels= */ 1},
+        options_.segmenter_options(), tensors_buffer));
   }
   return absl::OkStatus();
 }
@@ -345,9 +391,9 @@ std::vector<Image> TensorsToSegmentationCalculator::GetSegmentationResultCpu(
     const float* tensors_buffer) {
   if (options_.segmenter_options().output_type() ==
       SegmenterOptions::CATEGORY_MASK) {
-    return ProcessForCategoryMaskCpu(input_shape, output_shape,
-                                     options_.segmenter_options(),
-                                     tensors_buffer);
+    return {ProcessForCategoryMaskCpu(input_shape, output_shape,
+                                      options_.segmenter_options(),
+                                      tensors_buffer)};
   } else {
     return ProcessForConfidenceMaskCpu(input_shape, output_shape,
                                        options_.segmenter_options(),
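As a reading aid (not part of the commit), a sketch of how a graph author might wire the calculator's updated contract with the api2 builder. The function name, `num_mask_channels`, and the omitted calculator options are assumptions; the CATEGORY_MASK edge can simply be left unconnected when no category mask is wanted.

#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/formats/image.h"

// Sketch: wiring TensorsToSegmentationCalculator's CONFIDENCE_MASK and
// optional CATEGORY_MASK outputs (calculator options omitted for brevity).
mediapipe::CalculatorGraphConfig BuildSegmentationPostprocessing(
    int num_mask_channels /* assumed to match the model's output channels */) {
  using ::mediapipe::Image;
  using ::mediapipe::api2::Output;
  using ::mediapipe::api2::builder::Graph;

  Graph graph;
  auto& node = graph.AddNode("mediapipe.tasks.TensorsToSegmentationCalculator");
  graph.In("TENSORS") >> node.In("TENSORS");
  // One indexed CONFIDENCE_MASK stream per output channel; these are always
  // produced by the calculator.
  for (int i = 0; i < num_mask_channels; ++i) {
    node[Output<Image>::Multiple("CONFIDENCE_MASK")][i] >>
        graph[Output<Image>::Multiple("CONFIDENCE_MASK")][i];
  }
  // CATEGORY_MASK is optional: the calculator emits it only when connected.
  node[Output<Image>("CATEGORY_MASK")] >> graph[Output<Image>("CATEGORY_MASK")];
  return graph.GetConfig();
}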
@@ -79,8 +79,9 @@ void PushTensorsToRunner(int tensor_height, int tensor_width,
 std::vector<Packet> GetPackets(const CalculatorRunner& runner) {
   std::vector<Packet> mask_packets;
   for (int i = 0; i < runner.Outputs().NumEntries(); ++i) {
-    EXPECT_EQ(runner.Outputs().Get("SEGMENTATION", i).packets.size(), 1);
-    mask_packets.push_back(runner.Outputs().Get("SEGMENTATION", i).packets[0]);
+    EXPECT_EQ(runner.Outputs().Get("CONFIDENCE_MASK", i).packets.size(), 1);
+    mask_packets.push_back(
+        runner.Outputs().Get("CONFIDENCE_MASK", i).packets[0]);
   }
   return mask_packets;
 }
@@ -118,13 +119,10 @@ TEST(TensorsToSegmentationCalculatorTest, FailsInvalidTensorDimensionOne) {
       R"pb(
         calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
         input_stream: "TENSORS:tensors"
-        output_stream: "SEGMENTATION:segmentation"
+        output_stream: "CONFIDENCE_MASK:segmentation"
         options {
           [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] {
-            segmenter_options {
-              activation: SOFTMAX
-              output_type: CONFIDENCE_MASK
-            }
+            segmenter_options { activation: SOFTMAX }
           }
         }
       )pb"));
@@ -145,13 +143,10 @@ TEST(TensorsToSegmentationCalculatorTest, FailsInvalidTensorDimensionFive) {
       R"pb(
         calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
         input_stream: "TENSORS:tensors"
-        output_stream: "SEGMENTATION:segmentation"
+        output_stream: "CONFIDENCE_MASK:segmentation"
         options {
           [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] {
-            segmenter_options {
-              activation: SOFTMAX
-              output_type: CONFIDENCE_MASK
-            }
+            segmenter_options { activation: SOFTMAX }
           }
         }
       )pb"));
@@ -173,16 +168,13 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsConfidenceMaskWithSoftmax) {
       R"pb(
         calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
         input_stream: "TENSORS:tensors"
-        output_stream: "SEGMENTATION:0:segmented_mask_0"
-        output_stream: "SEGMENTATION:1:segmented_mask_1"
-        output_stream: "SEGMENTATION:2:segmented_mask_2"
-        output_stream: "SEGMENTATION:3:segmented_mask_3"
+        output_stream: "CONFIDENCE_MASK:0:segmented_mask_0"
+        output_stream: "CONFIDENCE_MASK:1:segmented_mask_1"
+        output_stream: "CONFIDENCE_MASK:2:segmented_mask_2"
+        output_stream: "CONFIDENCE_MASK:3:segmented_mask_3"
         options {
           [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] {
-            segmenter_options {
-              activation: SOFTMAX
-              output_type: CONFIDENCE_MASK
-            }
+            segmenter_options { activation: SOFTMAX }
           }
         }
       )pb"));
@@ -218,16 +210,13 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsConfidenceMaskWithNone) {
       R"pb(
         calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
         input_stream: "TENSORS:tensors"
-        output_stream: "SEGMENTATION:0:segmented_mask_0"
-        output_stream: "SEGMENTATION:1:segmented_mask_1"
-        output_stream: "SEGMENTATION:2:segmented_mask_2"
-        output_stream: "SEGMENTATION:3:segmented_mask_3"
+        output_stream: "CONFIDENCE_MASK:0:segmented_mask_0"
+        output_stream: "CONFIDENCE_MASK:1:segmented_mask_1"
+        output_stream: "CONFIDENCE_MASK:2:segmented_mask_2"
+        output_stream: "CONFIDENCE_MASK:3:segmented_mask_3"
         options {
           [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] {
-            segmenter_options {
-              activation: NONE
-              output_type: CONFIDENCE_MASK
-            }
+            segmenter_options { activation: NONE }
           }
         }
       )pb"));
@@ -259,16 +248,13 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsConfidenceMaskWithSigmoid) {
       R"pb(
         calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
         input_stream: "TENSORS:tensors"
-        output_stream: "SEGMENTATION:0:segmented_mask_0"
-        output_stream: "SEGMENTATION:1:segmented_mask_1"
-        output_stream: "SEGMENTATION:2:segmented_mask_2"
-        output_stream: "SEGMENTATION:3:segmented_mask_3"
+        output_stream: "CONFIDENCE_MASK:0:segmented_mask_0"
+        output_stream: "CONFIDENCE_MASK:1:segmented_mask_1"
+        output_stream: "CONFIDENCE_MASK:2:segmented_mask_2"
+        output_stream: "CONFIDENCE_MASK:3:segmented_mask_3"
         options {
           [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] {
-            segmenter_options {
-              activation: SIGMOID
-              output_type: CONFIDENCE_MASK
-            }
+            segmenter_options { activation: SIGMOID }
          }
         }
       )pb"));
@@ -301,13 +287,14 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMask) {
       R"pb(
         calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
         input_stream: "TENSORS:tensors"
-        output_stream: "SEGMENTATION:segmentation"
+        output_stream: "CONFIDENCE_MASK:0:segmented_mask_0"
+        output_stream: "CONFIDENCE_MASK:1:segmented_mask_1"
+        output_stream: "CONFIDENCE_MASK:2:segmented_mask_2"
+        output_stream: "CONFIDENCE_MASK:3:segmented_mask_3"
+        output_stream: "CATEGORY_MASK:segmentation"
         options {
           [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] {
-            segmenter_options {
-              activation: NONE
-              output_type: CATEGORY_MASK
-            }
+            segmenter_options { activation: NONE }
           }
         }
       )pb"));
@@ -318,11 +305,11 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMask) {
       tensor_height, tensor_width,
       std::vector<float>(kTestValues.begin(), kTestValues.end()), &runner);
   MP_ASSERT_OK(runner.Run());
-  ASSERT_EQ(runner.Outputs().NumEntries(), 1);
+  ASSERT_EQ(runner.Outputs().NumEntries(), 5);
   // Largest element index is 3.
   const int expected_index = 3;
   const std::vector<int> buffer_indices = {0};
-  std::vector<Packet> packets = GetPackets(runner);
+  std::vector<Packet> packets = runner.Outputs().Tag("CATEGORY_MASK").packets;
   EXPECT_THAT(packets, testing::ElementsAre(
                            Uint8ImagePacket(tensor_height, tensor_width,
                                             expected_index, buffer_indices)));
@@ -335,13 +322,14 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMaskResize) {
         calculator: "mediapipe.tasks.TensorsToSegmentationCalculator"
         input_stream: "TENSORS:tensors"
         input_stream: "OUTPUT_SIZE:size"
-        output_stream: "SEGMENTATION:segmentation"
+        output_stream: "CONFIDENCE_MASK:0:segmented_mask_0"
+        output_stream: "CONFIDENCE_MASK:1:segmented_mask_1"
+        output_stream: "CONFIDENCE_MASK:2:segmented_mask_2"
+        output_stream: "CONFIDENCE_MASK:3:segmented_mask_3"
+        output_stream: "CATEGORY_MASK:segmentation"
        options {
           [mediapipe.tasks.TensorsToSegmentationCalculatorOptions.ext] {
-            segmenter_options {
-              activation: NONE
-              output_type: CATEGORY_MASK
-            }
+            segmenter_options { activation: NONE }
           }
         }
       )pb"));
@@ -367,7 +355,7 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMaskResize) {
   const std::vector<int> buffer_indices = {
       0 * output_width + 0, 0 * output_width + 1, 1 * output_width + 0,
       1 * output_width + 1};
-  std::vector<Packet> packets = GetPackets(runner);
+  std::vector<Packet> packets = runner.Outputs().Tag("CATEGORY_MASK").packets;
   EXPECT_THAT(packets, testing::ElementsAre(
                            Uint8ImagePacket(output_height, output_width,
                                             expected_index, buffer_indices)));
@@ -37,8 +37,10 @@ namespace vision {
 namespace image_segmenter {
 namespace {

-constexpr char kSegmentationStreamName[] = "segmented_mask_out";
-constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
+constexpr char kConfidenceMasksTag[] = "CONFIDENCE_MASKS";
+constexpr char kConfidenceMasksStreamName[] = "confidence_masks";
+constexpr char kCategoryMaskTag[] = "CATEGORY_MASK";
+constexpr char kCategoryMaskStreamName[] = "category_mask";
 constexpr char kImageInStreamName[] = "image_in";
 constexpr char kImageOutStreamName[] = "image_out";
 constexpr char kImageTag[] = "IMAGE";
@@ -51,7 +53,6 @@ constexpr int kMicroSecondsPerMilliSecond = 1000;
 using ::mediapipe::CalculatorGraphConfig;
 using ::mediapipe::Image;
 using ::mediapipe::NormalizedRect;
-using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions;
 using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision::
     image_segmenter::proto::ImageSegmenterGraphOptions;

@@ -59,21 +60,24 @@ using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision::
 // "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph".
 CalculatorGraphConfig CreateGraphConfig(
     std::unique_ptr<ImageSegmenterGraphOptionsProto> options,
-    bool enable_flow_limiting) {
+    bool output_category_mask, bool enable_flow_limiting) {
   api2::builder::Graph graph;
   auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
   task_subgraph.GetOptions<ImageSegmenterGraphOptionsProto>().Swap(
       options.get());
   graph.In(kImageTag).SetName(kImageInStreamName);
   graph.In(kNormRectTag).SetName(kNormRectStreamName);
-  task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >>
-      graph.Out(kGroupedSegmentationTag);
+  task_subgraph.Out(kConfidenceMasksTag).SetName(kConfidenceMasksStreamName) >>
+      graph.Out(kConfidenceMasksTag);
+  if (output_category_mask) {
+    task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
+        graph.Out(kCategoryMaskTag);
+  }
   task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
       graph.Out(kImageTag);
   if (enable_flow_limiting) {
-    return tasks::core::AddFlowLimiterCalculator(graph, task_subgraph,
-                                                 {kImageTag, kNormRectTag},
-                                                 kGroupedSegmentationTag);
+    return tasks::core::AddFlowLimiterCalculator(
+        graph, task_subgraph, {kImageTag, kNormRectTag}, kConfidenceMasksTag);
   }
   graph.In(kImageTag) >> task_subgraph.In(kImageTag);
   graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag);
@@ -91,16 +95,6 @@ ConvertImageSegmenterOptionsToProto(ImageSegmenterOptions* options) {
   options_proto->mutable_base_options()->set_use_stream_mode(
       options->running_mode != core::RunningMode::IMAGE);
   options_proto->set_display_names_locale(options->display_names_locale);
-  switch (options->output_type) {
-    case ImageSegmenterOptions::OutputType::CATEGORY_MASK:
-      options_proto->mutable_segmenter_options()->set_output_type(
-          SegmenterOptions::CATEGORY_MASK);
-      break;
-    case ImageSegmenterOptions::OutputType::CONFIDENCE_MASK:
-      options_proto->mutable_segmenter_options()->set_output_type(
-          SegmenterOptions::CONFIDENCE_MASK);
-      break;
-  }
   return options_proto;
 }

@@ -145,6 +139,7 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
   tasks::core::PacketsCallback packets_callback = nullptr;
   if (options->result_callback) {
     auto result_callback = options->result_callback;
+    bool output_category_mask = options->output_category_mask;
     packets_callback =
         [=](absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
           if (!status_or_packets.ok()) {
@@ -156,34 +151,41 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
           if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) {
             return;
           }
-          Packet segmented_masks =
-              status_or_packets.value()[kSegmentationStreamName];
+          Packet confidence_masks =
+              status_or_packets.value()[kConfidenceMasksStreamName];
+          std::optional<Image> category_mask;
+          if (output_category_mask) {
+            category_mask =
+                status_or_packets.value()[kCategoryMaskStreamName].Get<Image>();
+          }
           Packet image_packet = status_or_packets.value()[kImageOutStreamName];
-          result_callback(segmented_masks.Get<std::vector<Image>>(),
-                          image_packet.Get<Image>(),
-                          segmented_masks.Timestamp().Value() /
-                              kMicroSecondsPerMilliSecond);
+          result_callback(
+              {{confidence_masks.Get<std::vector<Image>>(), category_mask}},
+              image_packet.Get<Image>(),
+              confidence_masks.Timestamp().Value() /
+                  kMicroSecondsPerMilliSecond);
         };
   }

   auto image_segmenter =
       core::VisionTaskApiFactory::Create<ImageSegmenter,
                                          ImageSegmenterGraphOptionsProto>(
           CreateGraphConfig(
-              std::move(options_proto),
+              std::move(options_proto), options->output_category_mask,
              options->running_mode == core::RunningMode::LIVE_STREAM),
           std::move(options->base_options.op_resolver), options->running_mode,
           std::move(packets_callback));
   if (!image_segmenter.ok()) {
     return image_segmenter.status();
   }
+  image_segmenter.value()->output_category_mask_ =
+      options->output_category_mask;
   ASSIGN_OR_RETURN(
       (*image_segmenter)->labels_,
       GetLabelsFromGraphConfig((*image_segmenter)->runner_->GetGraphConfig()));
   return image_segmenter;
 }

-absl::StatusOr<std::vector<Image>> ImageSegmenter::Segment(
+absl::StatusOr<ImageSegmenterResult> ImageSegmenter::Segment(
     mediapipe::Image image,
     std::optional<core::ImageProcessingOptions> image_processing_options) {
   if (image.UsesGpu()) {
@@ -201,11 +203,17 @@ absl::StatusOr<std::vector<Image>> ImageSegmenter::Segment(
       {{kImageInStreamName, mediapipe::MakePacket<Image>(std::move(image))},
        {kNormRectStreamName,
         MakePacket<NormalizedRect>(std::move(norm_rect))}}));
-  return output_packets[kSegmentationStreamName].Get<std::vector<Image>>();
+  std::vector<Image> confidence_masks =
+      output_packets[kConfidenceMasksStreamName].Get<std::vector<Image>>();
+  std::optional<Image> category_mask;
+  if (output_category_mask_) {
+    category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
+  }
+  return {{confidence_masks, category_mask}};
 }

-absl::StatusOr<std::vector<Image>> ImageSegmenter::SegmentForVideo(
-    mediapipe::Image image, int64 timestamp_ms,
+absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
+    mediapipe::Image image, int64_t timestamp_ms,
     std::optional<core::ImageProcessingOptions> image_processing_options) {
   if (image.UsesGpu()) {
     return CreateStatusWithPayload(
@@ -225,11 +233,17 @@ absl::StatusOr<std::vector<Image>> ImageSegmenter::SegmentForVideo(
        {kNormRectStreamName,
         MakePacket<NormalizedRect>(std::move(norm_rect))
             .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
-  return output_packets[kSegmentationStreamName].Get<std::vector<Image>>();
+  std::vector<Image> confidence_masks =
+      output_packets[kConfidenceMasksStreamName].Get<std::vector<Image>>();
+  std::optional<Image> category_mask;
+  if (output_category_mask_) {
+    category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
+  }
+  return {{confidence_masks, category_mask}};
 }

 absl::Status ImageSegmenter::SegmentAsync(
-    Image image, int64 timestamp_ms,
+    Image image, int64_t timestamp_ms,
     std::optional<core::ImageProcessingOptions> image_processing_options) {
   if (image.UsesGpu()) {
     return CreateStatusWithPayload(
@@ -26,6 +26,7 @@ limitations under the License.
 #include "mediapipe/tasks/cc/core/base_options.h"
 #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
 #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
+#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h"
 #include "tensorflow/lite/kernels/register.h"

 namespace mediapipe {
@@ -52,23 +53,14 @@ struct ImageSegmenterOptions {
   // Metadata, if any. Defaults to English.
   std::string display_names_locale = "en";

-  // The output type of segmentation results.
-  enum OutputType {
-    // Gives a single output mask where each pixel represents the class which
-    // the pixel in the original image was predicted to belong to.
-    CATEGORY_MASK = 0,
-    // Gives a list of output masks where, for each mask, each pixel represents
-    // the prediction confidence, usually in the [0, 1] range.
-    CONFIDENCE_MASK = 1,
-  };
-
-  OutputType output_type = OutputType::CATEGORY_MASK;
+  // Whether to output category mask.
+  bool output_category_mask = false;

   // The user-defined result callback for processing live stream data.
   // The result callback should only be specified when the running mode is set
   // to RunningMode::LIVE_STREAM.
-  std::function<void(absl::StatusOr<std::vector<mediapipe::Image>>,
-                     const Image&, int64)>
+  std::function<void(absl::StatusOr<ImageSegmenterResult>, const Image&,
+                     int64_t)>
       result_callback = nullptr;
 };

@@ -84,13 +76,9 @@ struct ImageSegmenterOptions {
 //    1 or 3).
 //  - if type is kTfLiteFloat32, NormalizationOptions are required to be
 //    attached to the metadata for input normalization.
-// Output tensors:
-//  (kTfLiteUInt8/kTfLiteFloat32)
-//   - list of segmented masks.
-//   - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1.
-//   - if `output_type` is CONFIDENCE_MASK, float32 Image list of size
-//     `channels`.
-//   - batch is always 1
+// Output ImageSegmenterResult:
+//   Provides confidence masks and an optional category mask if
+//   `output_category_mask` is set true.
 // An example of such model can be found at:
 // https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2
 class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
@@ -114,12 +102,8 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
   // setting its 'rotation_degrees' field. Note that specifying a
   // region-of-interest using the 'region_of_interest' field is NOT supported
   // and will result in an invalid argument error being returned.
-  //
-  // If the output_type is CATEGORY_MASK, the returned vector of images is
-  // per-category segmented image mask.
-  // If the output_type is CONFIDENCE_MASK, the returned vector of images
-  // contains only one confidence image mask.
-  absl::StatusOr<std::vector<mediapipe::Image>> Segment(
+  absl::StatusOr<ImageSegmenterResult> Segment(
       mediapipe::Image image,
       std::optional<core::ImageProcessingOptions> image_processing_options =
           std::nullopt);
@@ -137,13 +121,8 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
   // setting its 'rotation_degrees' field. Note that specifying a
   // region-of-interest using the 'region_of_interest' field is NOT supported
   // and will result in an invalid argument error being returned.
-  //
-  // If the output_type is CATEGORY_MASK, the returned vector of images is
-  // per-category segmented image mask.
-  // If the output_type is CONFIDENCE_MASK, the returned vector of images
-  // contains only one confidence image mask.
-  absl::StatusOr<std::vector<mediapipe::Image>> SegmentForVideo(
-      mediapipe::Image image, int64 timestamp_ms,
+  absl::StatusOr<ImageSegmenterResult> SegmentForVideo(
+      mediapipe::Image image, int64_t timestamp_ms,
       std::optional<core::ImageProcessingOptions> image_processing_options =
           std::nullopt);

@@ -164,17 +143,13 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
   // and will result in an invalid argument error being returned.
   //
   // The "result_callback" prvoides
-  //   - A vector of segmented image masks.
-  //     If the output_type is CATEGORY_MASK, the returned vector of images is
-  //     per-category segmented image mask.
-  //     If the output_type is CONFIDENCE_MASK, the returned vector of images
-  //     contains only one confidence image mask.
+  //   - An ImageSegmenterResult.
   //   - The const reference to the corresponding input image that the image
   //     segmentation runs on. Note that the const reference to the image will
   //     no longer be valid when the callback returns. To access the image data
   //     outside of the callback, callers need to make a copy of the image.
   //   - The input timestamp in milliseconds.
-  absl::Status SegmentAsync(mediapipe::Image image, int64 timestamp_ms,
+  absl::Status SegmentAsync(mediapipe::Image image, int64_t timestamp_ms,
                             std::optional<core::ImageProcessingOptions>
                                 image_processing_options = std::nullopt);

@@ -182,9 +157,9 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
   absl::Status Close() { return runner_->Close(); }

   // Get the category label list of the ImageSegmenter can recognize. For
-  // CATEGORY_MASK type, the index in the category mask corresponds to the
-  // category in the label list. For CONFIDENCE_MASK type, the output mask list
-  // at index corresponds to the category in the label list.
+  // CATEGORY_MASK, the index in the category mask corresponds to the category
+  // in the label list. For CONFIDENCE_MASK, the output mask list at index
+  // corresponds to the category in the label list.
   //
   // If there is no labelmap provided in the model file, empty label list is
   // returned.
@@ -192,6 +167,7 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {

  private:
   std::vector<std::string> labels_;
+  bool output_category_mask_;
 };

 }  // namespace image_segmenter
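To make the new result_callback signature concrete, a hedged sketch of configuring live-stream mode follows. Only the callback signature, `output_category_mask`, and the running-mode values are taken from the header above; the include paths, model path, and helper function are assumptions.

#include <memory>
#include <utility>

#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h"

namespace {
using ::mediapipe::tasks::vision::image_segmenter::ImageSegmenter;
using ::mediapipe::tasks::vision::image_segmenter::ImageSegmenterOptions;
using ::mediapipe::tasks::vision::image_segmenter::ImageSegmenterResult;

std::unique_ptr<ImageSegmenter> CreateLiveStreamSegmenter() {
  auto options = std::make_unique<ImageSegmenterOptions>();
  options->base_options.model_asset_path = "segmentation_model.tflite";  // assumed
  options->running_mode =
      mediapipe::tasks::vision::core::RunningMode::LIVE_STREAM;
  options->output_category_mask = true;
  options->result_callback = [](absl::StatusOr<ImageSegmenterResult> result,
                                const mediapipe::Image& image,
                                int64_t timestamp_ms) {
    if (!result.ok()) return;
    // confidence_masks is always filled; category_mask only when requested.
    // `image` is only valid for the duration of this callback.
  };
  return ImageSegmenter::Create(std::move(options)).value();
  // Later: segmenter->SegmentAsync(frame, frame_timestamp_ms) per frame, with
  // monotonically increasing timestamps.
}
}  // namespace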
@ -14,6 +14,7 @@ limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <optional>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
@ -42,6 +43,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
|
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
|
||||||
#include "mediapipe/tasks/metadata/image_segmenter_metadata_schema_generated.h"
|
#include "mediapipe/tasks/metadata/image_segmenter_metadata_schema_generated.h"
|
||||||
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
|
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
|
||||||
|
#include "mediapipe/util/graph_builder_utils.h"
|
||||||
#include "mediapipe/util/label_map.pb.h"
|
#include "mediapipe/util/label_map.pb.h"
|
||||||
#include "mediapipe/util/label_map_util.h"
|
#include "mediapipe/util/label_map_util.h"
|
||||||
#include "tensorflow/lite/schema/schema_generated.h"
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
|
@ -65,10 +67,13 @@ using ::mediapipe::tasks::vision::image_segmenter::proto::
|
||||||
ImageSegmenterGraphOptions;
|
ImageSegmenterGraphOptions;
|
||||||
using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions;
|
using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions;
|
||||||
using ::tflite::TensorMetadata;
|
using ::tflite::TensorMetadata;
|
||||||
using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
|
using LabelItems = mediapipe::proto_ns::Map<int64_t, ::mediapipe::LabelMapItem>;
|
||||||
|
|
||||||
constexpr char kSegmentationTag[] = "SEGMENTATION";
|
constexpr char kSegmentationTag[] = "SEGMENTATION";
|
||||||
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
|
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
|
||||||
|
constexpr char kConfidenceMaskTag[] = "CONFIDENCE_MASK";
|
||||||
|
constexpr char kConfidenceMasksTag[] = "CONFIDENCE_MASKS";
|
||||||
|
constexpr char kCategoryMaskTag[] = "CATEGORY_MASK";
|
||||||
constexpr char kImageTag[] = "IMAGE";
|
constexpr char kImageTag[] = "IMAGE";
|
||||||
constexpr char kImageCpuTag[] = "IMAGE_CPU";
|
constexpr char kImageCpuTag[] = "IMAGE_CPU";
|
||||||
constexpr char kImageGpuTag[] = "IMAGE_GPU";
|
constexpr char kImageGpuTag[] = "IMAGE_GPU";
|
||||||
|
@ -80,7 +85,9 @@ constexpr char kSegmentationMetadataName[] = "SEGMENTER_METADATA";
|
||||||
// Struct holding the different output streams produced by the image segmenter
|
// Struct holding the different output streams produced by the image segmenter
|
||||||
// subgraph.
|
// subgraph.
|
||||||
struct ImageSegmenterOutputs {
|
struct ImageSegmenterOutputs {
|
||||||
std::vector<Source<Image>> segmented_masks;
|
std::optional<std::vector<Source<Image>>> segmented_masks;
|
||||||
|
std::optional<std::vector<Source<Image>>> confidence_masks;
|
||||||
|
std::optional<Source<Image>> category_mask;
|
||||||
// The same as the input image, mainly used for live stream mode.
|
// The same as the input image, mainly used for live stream mode.
|
||||||
Source<Image> image;
|
Source<Image> image;
|
||||||
};
|
};
|
||||||
|
@ -95,8 +102,10 @@ struct ImageAndTensorsOnDevice {
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
absl::Status SanityCheckOptions(const ImageSegmenterGraphOptions& options) {
|
absl::Status SanityCheckOptions(const ImageSegmenterGraphOptions& options) {
|
||||||
if (options.segmenter_options().output_type() ==
|
// TODO: remove deprecated output type support.
|
||||||
SegmenterOptions::UNSPECIFIED) {
|
if (options.segmenter_options().has_output_type() &&
|
||||||
|
options.segmenter_options().output_type() ==
|
||||||
|
SegmenterOptions::UNSPECIFIED) {
|
||||||
return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
|
return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
|
||||||
"`output_type` must not be UNSPECIFIED",
|
"`output_type` must not be UNSPECIFIED",
|
||||||
MediaPipeTasksStatus::kInvalidArgumentError);
|
MediaPipeTasksStatus::kInvalidArgumentError);
|
||||||
|
@ -133,9 +142,8 @@ absl::Status ConfigureTensorsToSegmentationCalculator(
|
||||||
const core::ModelResources& model_resources,
|
const core::ModelResources& model_resources,
|
||||||
TensorsToSegmentationCalculatorOptions* options) {
|
TensorsToSegmentationCalculatorOptions* options) {
|
||||||
// Set default activation function NONE
|
// Set default activation function NONE
|
||||||
options->mutable_segmenter_options()->set_output_type(
|
options->mutable_segmenter_options()->CopyFrom(
|
||||||
segmenter_option.segmenter_options().output_type());
|
segmenter_option.segmenter_options());
|
||||||
options->mutable_segmenter_options()->set_activation(SegmenterOptions::NONE);
|
|
||||||
// Find the custom metadata of ImageSegmenterOptions type in model metadata.
|
// Find the custom metadata of ImageSegmenterOptions type in model metadata.
|
||||||
const auto* metadata_extractor = model_resources.GetMetadataExtractor();
|
const auto* metadata_extractor = model_resources.GetMetadataExtractor();
|
||||||
bool found_activation_in_metadata = false;
|
bool found_activation_in_metadata = false;
|
||||||
|
@ -317,12 +325,14 @@ absl::StatusOr<ImageAndTensorsOnDevice> ConvertImageToTensors(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// An "mediapipe.tasks.vision.ImageSegmenterGraph" performs semantic
|
// An "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph" performs
|
||||||
// segmentation.
|
// semantic segmentation. The graph always output confidence masks, and an
|
||||||
// Two kinds of outputs are provided: SEGMENTATION and GROUPED_SEGMENTATION.
|
// optional category mask if CATEGORY_MASK is connected.
|
||||||
// Users can retrieve segmented mask of only particular category/channel from
|
//
|
||||||
// SEGMENTATION, and users can also get all segmented masks from
|
// Two kinds of outputs for confidence mask are provided: CONFIDENCE_MASK and
|
||||||
// GROUPED_SEGMENTATION.
|
// CONFIDENCE_MASKS. Users can retrieve segmented mask of only particular
|
||||||
|
// category/channel from CONFIDENCE_MASK, and users can also get all segmented
|
||||||
|
// confidence masks from CONFIDENCE_MASKS.
|
||||||
// - Accepts CPU input images and outputs segmented masks on CPU.
|
// - Accepts CPU input images and outputs segmented masks on CPU.
|
||||||
//
|
//
|
||||||
// Inputs:
|
// Inputs:
|
||||||
|
@ -334,11 +344,13 @@ absl::StatusOr<ImageAndTensorsOnDevice> ConvertImageToTensors(
|
||||||
// @Optional: rect covering the whole image is used if not specified.
|
// @Optional: rect covering the whole image is used if not specified.
|
||||||
//
|
//
|
||||||
// Outputs:
|
// Outputs:
|
||||||
// SEGMENTATION - mediapipe::Image @Multiple
|
// CONFIDENCE_MASK - mediapipe::Image @Multiple
|
||||||
// Segmented masks for individual category. Segmented mask of single
|
// Confidence masks for individual category. Confidence mask of single
|
||||||
// category can be accessed by index based output stream.
|
// category can be accessed by index based output stream.
|
||||||
// GROUPED_SEGMENTATION - std::vector<mediapipe::Image>
|
// CONFIDENCE_MASKS - std::vector<mediapipe::Image>
|
||||||
// The output segmented masks grouped in a vector.
|
// The output confidence masks grouped in a vector.
|
||||||
|
// CATEGORY_MASK - mediapipe::Image @Optional
|
||||||
|
// Optional Category mask.
|
||||||
// IMAGE - mediapipe::Image
|
// IMAGE - mediapipe::Image
|
||||||
// The image that image segmenter runs on.
|
// The image that image segmenter runs on.
|
||||||
//
|
//
|
||||||
|
@ -369,23 +381,39 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
|
||||||
ASSIGN_OR_RETURN(const auto* model_resources,
|
ASSIGN_OR_RETURN(const auto* model_resources,
|
||||||
CreateModelResources<ImageSegmenterGraphOptions>(sc));
|
CreateModelResources<ImageSegmenterGraphOptions>(sc));
|
||||||
Graph graph;
|
Graph graph;
|
||||||
|
const auto& options = sc->Options<ImageSegmenterGraphOptions>();
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(
|
||||||
auto output_streams,
|
auto output_streams,
|
||||||
BuildSegmentationTask(
|
BuildSegmentationTask(
|
||||||
sc->Options<ImageSegmenterGraphOptions>(), *model_resources,
|
options, *model_resources, graph[Input<Image>(kImageTag)],
|
||||||
graph[Input<Image>(kImageTag)],
|
graph[Input<NormalizedRect>::Optional(kNormRectTag)],
|
||||||
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
|
HasOutput(sc->OriginalNode(), kCategoryMaskTag), graph));
|
||||||
|
|
||||||
auto& merge_images_to_vector =
|
auto& merge_images_to_vector =
|
||||||
graph.AddNode("MergeImagesToVectorCalculator");
|
graph.AddNode("MergeImagesToVectorCalculator");
|
||||||
for (int i = 0; i < output_streams.segmented_masks.size(); ++i) {
|
// TODO: remove deprecated output type support.
|
||||||
output_streams.segmented_masks[i] >>
|
if (options.segmenter_options().has_output_type()) {
|
||||||
merge_images_to_vector[Input<Image>::Multiple("")][i];
|
for (int i = 0; i < output_streams.segmented_masks->size(); ++i) {
|
||||||
output_streams.segmented_masks[i] >>
|
output_streams.segmented_masks->at(i) >>
|
||||||
graph[Output<Image>::Multiple(kSegmentationTag)][i];
|
merge_images_to_vector[Input<Image>::Multiple("")][i];
|
||||||
|
output_streams.segmented_masks->at(i) >>
|
||||||
|
graph[Output<Image>::Multiple(kSegmentationTag)][i];
|
||||||
|
}
|
||||||
|
merge_images_to_vector.Out("") >>
|
||||||
|
graph[Output<std::vector<Image>>(kGroupedSegmentationTag)];
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < output_streams.confidence_masks->size(); ++i) {
|
||||||
|
output_streams.confidence_masks->at(i) >>
|
||||||
|
merge_images_to_vector[Input<Image>::Multiple("")][i];
|
||||||
|
output_streams.confidence_masks->at(i) >>
|
||||||
|
graph[Output<Image>::Multiple(kConfidenceMaskTag)][i];
|
||||||
|
}
|
||||||
|
merge_images_to_vector.Out("") >>
|
||||||
|
graph[Output<std::vector<Image>>(kConfidenceMasksTag)];
|
||||||
|
if (output_streams.category_mask) {
|
||||||
|
*output_streams.category_mask >> graph[Output<Image>(kCategoryMaskTag)];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
merge_images_to_vector.Out("") >>
|
|
||||||
graph[Output<std::vector<Image>>(kGroupedSegmentationTag)];
|
|
||||||
output_streams.image >> graph[Output<Image>(kImageTag)];
|
output_streams.image >> graph[Output<Image>(kImageTag)];
|
||||||
return graph.GetConfig();
|
return graph.GetConfig();
|
||||||
}
|
}
|
||||||
|
@ -403,7 +431,8 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
|
||||||
absl::StatusOr<ImageSegmenterOutputs> BuildSegmentationTask(
|
absl::StatusOr<ImageSegmenterOutputs> BuildSegmentationTask(
|
||||||
const ImageSegmenterGraphOptions& task_options,
|
const ImageSegmenterGraphOptions& task_options,
|
||||||
const core::ModelResources& model_resources, Source<Image> image_in,
|
const core::ModelResources& model_resources, Source<Image> image_in,
|
||||||
Source<NormalizedRect> norm_rect_in, Graph& graph) {
|
Source<NormalizedRect> norm_rect_in, bool output_category_mask,
|
||||||
|
Graph& graph) {
|
||||||
MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));
|
MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));
|
||||||
|
|
||||||
// Adds preprocessing calculators and connects them to the graph input image
|
// Adds preprocessing calculators and connects them to the graph input image
|
||||||
|
@@ -435,22 +464,46 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
     image_properties.Out("SIZE") >> tensor_to_images.In(kOutputSizeTag);

     // Exports multiple segmented masks.
-    std::vector<Source<Image>> segmented_masks;
-    if (task_options.segmenter_options().output_type() ==
-        SegmenterOptions::CATEGORY_MASK) {
-      segmented_masks.push_back(
-          Source<Image>(tensor_to_images[Output<Image>(kSegmentationTag)]));
+    // TODO: remove deprecated output type support.
+    if (task_options.segmenter_options().has_output_type()) {
+      std::vector<Source<Image>> segmented_masks;
+      if (task_options.segmenter_options().output_type() ==
+          SegmenterOptions::CATEGORY_MASK) {
+        segmented_masks.push_back(
+            Source<Image>(tensor_to_images[Output<Image>(kSegmentationTag)]));
+      } else {
+        ASSIGN_OR_RETURN(const tflite::Tensor* output_tensor,
+                         GetOutputTensor(model_resources));
+        int segmentation_streams_num = *output_tensor->shape()->rbegin();
+        for (int i = 0; i < segmentation_streams_num; ++i) {
+          segmented_masks.push_back(Source<Image>(
+              tensor_to_images[Output<Image>::Multiple(kSegmentationTag)][i]));
+        }
+      }
+      return ImageSegmenterOutputs{/*segmented_masks=*/segmented_masks,
+                                   /*confidence_masks=*/std::nullopt,
+                                   /*category_mask=*/std::nullopt,
+                                   /*image=*/image_and_tensors.image};
     } else {
       ASSIGN_OR_RETURN(const tflite::Tensor* output_tensor,
                        GetOutputTensor(model_resources));
       int segmentation_streams_num = *output_tensor->shape()->rbegin();
+      std::vector<Source<Image>> confidence_masks;
+      confidence_masks.reserve(segmentation_streams_num);
       for (int i = 0; i < segmentation_streams_num; ++i) {
-        segmented_masks.push_back(Source<Image>(
-            tensor_to_images[Output<Image>::Multiple(kSegmentationTag)][i]));
+        confidence_masks.push_back(Source<Image>(
+            tensor_to_images[Output<Image>::Multiple(kConfidenceMaskTag)][i]));
       }
+      return ImageSegmenterOutputs{
+          /*segmented_masks=*/std::nullopt,
+          /*confidence_masks=*/confidence_masks,
+          /*category_mask=*/
+          output_category_mask
+              ? std::make_optional(
+                    tensor_to_images[Output<Image>(kCategoryMaskTag)])
+              : std::nullopt,
+          /*image=*/image_and_tensors.image};
     }
-    return ImageSegmenterOutputs{/*segmented_masks=*/segmented_masks,
-                                 /*image=*/image_and_tensors.image};
   }
 };

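The ImageSegmenterOutputs helper struct returned above is not itself shown in this hunk. Based on the designated initializers used here, it presumably now looks roughly like the sketch below; the field types are assumptions inferred from the diff, not part of the change as shown.

// Assumed shape of the internal helper struct used by BuildSegmentationTask.
struct ImageSegmenterOutputs {
  // Populated only on the deprecated output_type path.
  std::optional<std::vector<Source<Image>>> segmented_masks;
  // Always populated on the new path.
  std::optional<std::vector<Source<Image>>> confidence_masks;
  // Populated only when the caller requested the category mask output.
  std::optional<Source<Image>> category_mask;
  Source<Image> image;
};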
@@ -0,0 +1,44 @@
+/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_RESULT_H_
+#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_RESULT_H_
+
+#include <optional>
+#include <vector>
+
+#include "mediapipe/framework/formats/image.h"
+
+namespace mediapipe {
+namespace tasks {
+namespace vision {
+namespace image_segmenter {
+
+// The output result of ImageSegmenter.
+struct ImageSegmenterResult {
+  // Multiple masks of float image in VEC32F1 format where, for each mask, each
+  // pixel represents the prediction confidence, usually in the [0, 1] range.
+  std::vector<Image> confidence_masks;
+  // A category mask of uint8 image in GRAY8 format where each pixel represents
+  // the class which the pixel in the original image was predicted to belong to.
+  std::optional<Image> category_mask;
+};
+
+}  // namespace image_segmenter
+}  // namespace vision
+}  // namespace tasks
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_RESULT_H_
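A minimal sketch of how callers consume the new result type. The option name and Segment call follow the updated tests later in this change; the helper function itself is illustrative only and not part of the commit.

#include <optional>
#include <vector>

#include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h"

namespace {
using ::mediapipe::Image;
using ::mediapipe::tasks::vision::image_segmenter::ImageSegmenterResult;

// Confidence masks are always returned; the category mask is present only when
// the task was created with output_category_mask = true.
std::vector<Image> CollectAllMasks(const ImageSegmenterResult& result) {
  std::vector<Image> masks = result.confidence_masks;
  if (result.category_mask.has_value()) {
    masks.push_back(*result.category_mask);
  }
  return masks;
}
}  // namespace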
@@ -36,6 +36,7 @@ limitations under the License.
 #include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
 #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
 #include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h"
+#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h"
 #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
 #include "mediapipe/tasks/cc/vision/utils/image_utils.h"
 #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
@@ -256,7 +257,6 @@ TEST(GetLabelsTest, SucceedsWithLabelsInModel) {
   auto options = std::make_unique<ImageSegmenterOptions>();
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
-  options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;

   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                           ImageSegmenter::Create(std::move(options)));
@@ -278,15 +278,14 @@ TEST_F(ImageModeTest, SucceedsWithCategoryMask) {
   auto options = std::make_unique<ImageSegmenterOptions>();
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
-  options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
+  options->output_category_mask = true;

   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                           ImageSegmenter::Create(std::move(options)));
-  MP_ASSERT_OK_AND_ASSIGN(auto category_masks, segmenter->Segment(image));
-  EXPECT_EQ(category_masks.size(), 1);
+  MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
+  EXPECT_TRUE(result.category_mask.has_value());

   cv::Mat actual_mask = mediapipe::formats::MatView(
-      category_masks[0].GetImageFrameSharedPtr().get());
+      result.category_mask->GetImageFrameSharedPtr().get());

   cv::Mat expected_mask = cv::imread(
       JoinPath("./", kTestDataDirectory, "segmentation_golden_rotation0.png"),
@@ -303,12 +302,11 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {
   auto options = std::make_unique<ImageSegmenterOptions>();
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
-  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;

   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                           ImageSegmenter::Create(std::move(options)));
-  MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
-  EXPECT_EQ(confidence_masks.size(), 21);
+  MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
+  EXPECT_EQ(result.confidence_masks.size(), 21);

   cv::Mat expected_mask = cv::imread(
       JoinPath("./", kTestDataDirectory, "cat_mask.jpg"), cv::IMREAD_GRAYSCALE);
@@ -317,7 +315,7 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {

   // Cat category index 8.
   cv::Mat cat_mask = mediapipe::formats::MatView(
-      confidence_masks[8].GetImageFrameSharedPtr().get());
+      result.confidence_masks[8].GetImageFrameSharedPtr().get());
   EXPECT_THAT(cat_mask,
               SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
 }
@@ -331,15 +329,14 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) {
   auto options = std::make_unique<ImageSegmenterOptions>();
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
-  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;

   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                           ImageSegmenter::Create(std::move(options)));
   ImageProcessingOptions image_processing_options;
   image_processing_options.rotation_degrees = -90;
-  MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks,
+  MP_ASSERT_OK_AND_ASSIGN(auto result,
                           segmenter->Segment(image, image_processing_options));
-  EXPECT_EQ(confidence_masks.size(), 21);
+  EXPECT_EQ(result.confidence_masks.size(), 21);

   cv::Mat expected_mask =
       cv::imread(JoinPath("./", kTestDataDirectory, "cat_rotated_mask.jpg"),
@@ -349,7 +346,7 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) {

   // Cat category index 8.
   cv::Mat cat_mask = mediapipe::formats::MatView(
-      confidence_masks[8].GetImageFrameSharedPtr().get());
+      result.confidence_masks[8].GetImageFrameSharedPtr().get());
   EXPECT_THAT(cat_mask,
               SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
 }
@@ -361,7 +358,6 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
   auto options = std::make_unique<ImageSegmenterOptions>();
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
-  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;

   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                           ImageSegmenter::Create(std::move(options)));
@@ -384,12 +380,11 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) {
   auto options = std::make_unique<ImageSegmenterOptions>();
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kSelfie128x128WithMetadata);
-  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;

   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                           ImageSegmenter::Create(std::move(options)));
-  MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
-  EXPECT_EQ(confidence_masks.size(), 2);
+  MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
+  EXPECT_EQ(result.confidence_masks.size(), 2);

   cv::Mat expected_mask =
       cv::imread(JoinPath("./", kTestDataDirectory,
@@ -400,7 +395,7 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) {

   // Selfie category index 1.
   cv::Mat selfie_mask = mediapipe::formats::MatView(
-      confidence_masks[1].GetImageFrameSharedPtr().get());
+      result.confidence_masks[1].GetImageFrameSharedPtr().get());
   EXPECT_THAT(selfie_mask,
               SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
 }
@@ -411,11 +406,10 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) {
   auto options = std::make_unique<ImageSegmenterOptions>();
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kSelfie144x256WithMetadata);
-  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                           ImageSegmenter::Create(std::move(options)));
-  MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
-  EXPECT_EQ(confidence_masks.size(), 1);
+  MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
+  EXPECT_EQ(result.confidence_masks.size(), 1);

   cv::Mat expected_mask =
       cv::imread(JoinPath("./", kTestDataDirectory,
@@ -425,7 +419,7 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) {
   expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);

   cv::Mat selfie_mask = mediapipe::formats::MatView(
-      confidence_masks[0].GetImageFrameSharedPtr().get());
+      result.confidence_masks[0].GetImageFrameSharedPtr().get());
   EXPECT_THAT(selfie_mask,
               SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
 }
@@ -436,12 +430,11 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) {
   auto options = std::make_unique<ImageSegmenterOptions>();
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kSelfieSegmentation);
-  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;

   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                           ImageSegmenter::Create(std::move(options)));
-  MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
-  EXPECT_EQ(confidence_masks.size(), 1);
+  MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
+  EXPECT_EQ(result.confidence_masks.size(), 1);
   MP_ASSERT_OK(segmenter->Close());

   cv::Mat expected_mask = cv::imread(
@@ -452,7 +445,7 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) {
   expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);

   cv::Mat selfie_mask = mediapipe::formats::MatView(
-      confidence_masks[0].GetImageFrameSharedPtr().get());
+      result.confidence_masks[0].GetImageFrameSharedPtr().get());
   EXPECT_THAT(selfie_mask,
               SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
 }
@@ -463,16 +456,15 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationCategoryMask) {
   auto options = std::make_unique<ImageSegmenterOptions>();
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kSelfieSegmentation);
-  options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
+  options->output_category_mask = true;

   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                           ImageSegmenter::Create(std::move(options)));
-  MP_ASSERT_OK_AND_ASSIGN(auto category_mask, segmenter->Segment(image));
-  EXPECT_EQ(category_mask.size(), 1);
+  MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
+  EXPECT_TRUE(result.category_mask.has_value());
   MP_ASSERT_OK(segmenter->Close());

   cv::Mat selfie_mask = mediapipe::formats::MatView(
-      category_mask[0].GetImageFrameSharedPtr().get());
+      result.category_mask->GetImageFrameSharedPtr().get());
   cv::Mat expected_mask = cv::imread(
       JoinPath("./", kTestDataDirectory,
                "portrait_selfie_segmentation_expected_category_mask.jpg"),
@@ -487,16 +479,15 @@ TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationLandscapeCategoryMask) {
   auto options = std::make_unique<ImageSegmenterOptions>();
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kSelfieSegmentationLandscape);
-  options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
+  options->output_category_mask = true;

   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                           ImageSegmenter::Create(std::move(options)));
-  MP_ASSERT_OK_AND_ASSIGN(auto category_mask, segmenter->Segment(image));
-  EXPECT_EQ(category_mask.size(), 1);
+  MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
+  EXPECT_TRUE(result.category_mask.has_value());
   MP_ASSERT_OK(segmenter->Close());

   cv::Mat selfie_mask = mediapipe::formats::MatView(
-      category_mask[0].GetImageFrameSharedPtr().get());
+      result.category_mask->GetImageFrameSharedPtr().get());
   cv::Mat expected_mask = cv::imread(
       JoinPath(
           "./", kTestDataDirectory,
@@ -512,14 +503,13 @@ TEST_F(ImageModeTest, SucceedsHairSegmentation) {
   auto options = std::make_unique<ImageSegmenterOptions>();
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kHairSegmentationWithMetadata);
-  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                           ImageSegmenter::Create(std::move(options)));
-  MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
-  EXPECT_EQ(confidence_masks.size(), 2);
+  MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->Segment(image));
+  EXPECT_EQ(result.confidence_masks.size(), 2);

   cv::Mat hair_mask = mediapipe::formats::MatView(
-      confidence_masks[1].GetImageFrameSharedPtr().get());
+      result.confidence_masks[1].GetImageFrameSharedPtr().get());
   MP_ASSERT_OK(segmenter->Close());
   cv::Mat expected_mask = cv::imread(
       JoinPath("./", kTestDataDirectory, "portrait_hair_expected_mask.jpg"),
@@ -540,7 +530,6 @@ TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
   auto options = std::make_unique<ImageSegmenterOptions>();
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
-  options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
   options->running_mode = core::RunningMode::VIDEO;

   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
@@ -572,7 +561,7 @@ TEST_F(VideoModeTest, Succeeds) {
   auto options = std::make_unique<ImageSegmenterOptions>();
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
-  options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
+  options->output_category_mask = true;
   options->running_mode = core::RunningMode::VIDEO;
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                           ImageSegmenter::Create(std::move(options)));
@@ -580,11 +569,10 @@ TEST_F(VideoModeTest, Succeeds) {
       JoinPath("./", kTestDataDirectory, "segmentation_golden_rotation0.png"),
       cv::IMREAD_GRAYSCALE);
   for (int i = 0; i < iterations; ++i) {
-    MP_ASSERT_OK_AND_ASSIGN(auto category_masks,
-                            segmenter->SegmentForVideo(image, i));
-    EXPECT_EQ(category_masks.size(), 1);
+    MP_ASSERT_OK_AND_ASSIGN(auto result, segmenter->SegmentForVideo(image, i));
+    EXPECT_TRUE(result.category_mask.has_value());
     cv::Mat actual_mask = mediapipe::formats::MatView(
-        category_masks[0].GetImageFrameSharedPtr().get());
+        result.category_mask->GetImageFrameSharedPtr().get());
     EXPECT_THAT(actual_mask,
                 SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity,
                                    kGoldenMaskMagnificationFactor));
@@ -601,11 +589,10 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
   auto options = std::make_unique<ImageSegmenterOptions>();
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
-  options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
   options->running_mode = core::RunningMode::LIVE_STREAM;
   options->result_callback =
-      [](absl::StatusOr<std::vector<Image>> segmented_masks, const Image& image,
-         int64 timestamp_ms) {};
+      [](absl::StatusOr<ImageSegmenterResult> segmented_masks,
+         const Image& image, int64_t timestamp_ms) {};
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                           ImageSegmenter::Create(std::move(options)));

@@ -634,11 +621,9 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
   auto options = std::make_unique<ImageSegmenterOptions>();
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
-  options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
   options->running_mode = core::RunningMode::LIVE_STREAM;
-  options->result_callback =
-      [](absl::StatusOr<std::vector<Image>> segmented_masks, const Image& image,
-         int64 timestamp_ms) {};
+  options->result_callback = [](absl::StatusOr<ImageSegmenterResult> result,
+                                const Image& image, int64_t timestamp_ms) {};
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                           ImageSegmenter::Create(std::move(options)));
   MP_ASSERT_OK(segmenter->SegmentAsync(image, 1));
@@ -660,23 +645,23 @@ TEST_F(LiveStreamModeTest, Succeeds) {
       Image image,
       DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
                                    "segmentation_input_rotation0.jpg")));
-  std::vector<std::vector<Image>> segmented_masks_results;
+  std::vector<Image> segmented_masks_results;
   std::vector<std::pair<int, int>> image_sizes;
-  std::vector<int64> timestamps;
+  std::vector<int64_t> timestamps;
   auto options = std::make_unique<ImageSegmenterOptions>();
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
-  options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
+  options->output_category_mask = true;
   options->running_mode = core::RunningMode::LIVE_STREAM;
-  options->result_callback =
-      [&segmented_masks_results, &image_sizes, &timestamps](
-          absl::StatusOr<std::vector<Image>> segmented_masks,
-          const Image& image, int64 timestamp_ms) {
-        MP_ASSERT_OK(segmented_masks.status());
-        segmented_masks_results.push_back(std::move(segmented_masks).value());
+  options->result_callback = [&segmented_masks_results, &image_sizes,
+                              &timestamps](
+                                 absl::StatusOr<ImageSegmenterResult> result,
+                                 const Image& image, int64_t timestamp_ms) {
+    MP_ASSERT_OK(result.status());
+    segmented_masks_results.push_back(std::move(*result->category_mask));
     image_sizes.push_back({image.width(), image.height()});
     timestamps.push_back(timestamp_ms);
   };
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                           ImageSegmenter::Create(std::move(options)));
   for (int i = 0; i < iterations; ++i) {
@@ -690,10 +675,9 @@ TEST_F(LiveStreamModeTest, Succeeds) {
   cv::Mat expected_mask = cv::imread(
       JoinPath("./", kTestDataDirectory, "segmentation_golden_rotation0.png"),
       cv::IMREAD_GRAYSCALE);
-  for (const auto& segmented_masks : segmented_masks_results) {
-    EXPECT_EQ(segmented_masks.size(), 1);
+  for (const auto& category_mask : segmented_masks_results) {
     cv::Mat actual_mask = mediapipe::formats::MatView(
-        segmented_masks[0].GetImageFrameSharedPtr().get());
+        category_mask.GetImageFrameSharedPtr().get());
     EXPECT_THAT(actual_mask,
                 SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity,
                                    kGoldenMaskMagnificationFactor));
@@ -702,7 +686,7 @@ TEST_F(LiveStreamModeTest, Succeeds) {
     EXPECT_EQ(image_size.first, image.width());
     EXPECT_EQ(image_size.second, image.height());
   }
-  int64 timestamp_ms = -1;
+  int64_t timestamp_ms = -1;
   for (const auto& timestamp : timestamps) {
     EXPECT_GT(timestamp, timestamp_ms);
     timestamp_ms = timestamp;
@@ -33,7 +33,7 @@ message SegmenterOptions {
     CONFIDENCE_MASK = 2;
   }
   // Optional output mask type.
-  optional OutputType output_type = 1 [default = CATEGORY_MASK];
+  optional OutputType output_type = 1 [deprecated = true];

   // Supported activation functions for filtering.
   enum Activation {