Internal change
PiperOrigin-RevId: 522631851
This commit is contained in:
parent
e3185e3df0
commit
a1ce19f68e
|
@ -24,21 +24,26 @@ cc_library(
|
|||
hdrs = ["interactive_segmenter.h"],
|
||||
deps = [
|
||||
":interactive_segmenter_graph",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/formats:image",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/tasks/cc:common",
|
||||
"//mediapipe/tasks/cc/components/containers:keypoint",
|
||||
"//mediapipe/tasks/cc/core:base_options",
|
||||
"//mediapipe/tasks/cc/vision/core:base_vision_task_api",
|
||||
"//mediapipe/tasks/cc/vision/core:image_processing_options",
|
||||
"//mediapipe/tasks/cc/vision/core:running_mode",
|
||||
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_result",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto",
|
||||
"//mediapipe/util:color_cc_proto",
|
||||
"//mediapipe/util:render_data_cc_proto",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -61,9 +66,12 @@ cc_library(
|
|||
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto",
|
||||
"//mediapipe/util:color_cc_proto",
|
||||
"//mediapipe/util:graph_builder_utils",
|
||||
"//mediapipe/util:label_map_cc_proto",
|
||||
"//mediapipe/util:render_data_cc_proto",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
] + select({
|
||||
"//mediapipe/gpu:disable_gpu": [],
|
||||
|
|
|
@ -15,16 +15,24 @@ limitations under the License.
|
|||
|
||||
#include "mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h"
|
||||
#include "mediapipe/util/color.pb.h"
|
||||
|
@ -36,23 +44,26 @@ namespace vision {
|
|||
namespace interactive_segmenter {
|
||||
namespace {
|
||||
|
||||
constexpr char kSegmentationStreamName[] = "segmented_mask_out";
|
||||
constexpr char kConfidenceMasksStreamName[] = "confidence_masks";
|
||||
constexpr char kCategoryMaskStreamName[] = "category_mask";
|
||||
constexpr char kImageInStreamName[] = "image_in";
|
||||
constexpr char kImageOutStreamName[] = "image_out";
|
||||
constexpr char kRoiStreamName[] = "roi_in";
|
||||
constexpr char kNormRectStreamName[] = "norm_rect_in";
|
||||
|
||||
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
|
||||
constexpr char kImageTag[] = "IMAGE";
|
||||
constexpr char kRoiTag[] = "ROI";
|
||||
constexpr char kNormRectTag[] = "NORM_RECT";
|
||||
constexpr absl::string_view kConfidenceMasksTag{"CONFIDENCE_MASKS"};
|
||||
constexpr absl::string_view kCategoryMaskTag{"CATEGORY_MASK"};
|
||||
constexpr absl::string_view kImageTag{"IMAGE"};
|
||||
constexpr absl::string_view kRoiTag{"ROI"};
|
||||
constexpr absl::string_view kNormRectTag{"NORM_RECT"};
|
||||
|
||||
constexpr char kSubgraphTypeName[] =
|
||||
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph";
|
||||
constexpr absl::string_view kSubgraphTypeName{
|
||||
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"};
|
||||
|
||||
using ::mediapipe::CalculatorGraphConfig;
|
||||
using ::mediapipe::Image;
|
||||
using ::mediapipe::NormalizedRect;
|
||||
using ::mediapipe::tasks::vision::image_segmenter::ImageSegmenterResult;
|
||||
using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions;
|
||||
using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision::
|
||||
image_segmenter::proto::ImageSegmenterGraphOptions;
|
||||
|
@ -60,7 +71,8 @@ using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision::
|
|||
// Creates a MediaPipe graph config that only contains a single subgraph node of
|
||||
// "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph".
|
||||
CalculatorGraphConfig CreateGraphConfig(
|
||||
std::unique_ptr<ImageSegmenterGraphOptionsProto> options) {
|
||||
std::unique_ptr<ImageSegmenterGraphOptionsProto> options,
|
||||
bool output_confidence_masks, bool output_category_mask) {
|
||||
api2::builder::Graph graph;
|
||||
auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
|
||||
task_subgraph.GetOptions<ImageSegmenterGraphOptionsProto>().Swap(
|
||||
|
@ -68,8 +80,15 @@ CalculatorGraphConfig CreateGraphConfig(
|
|||
graph.In(kImageTag).SetName(kImageInStreamName);
|
||||
graph.In(kRoiTag).SetName(kRoiStreamName);
|
||||
graph.In(kNormRectTag).SetName(kNormRectStreamName);
|
||||
task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >>
|
||||
graph.Out(kGroupedSegmentationTag);
|
||||
if (output_confidence_masks) {
|
||||
task_subgraph.Out(kConfidenceMasksTag)
|
||||
.SetName(kConfidenceMasksStreamName) >>
|
||||
graph.Out(kConfidenceMasksTag);
|
||||
}
|
||||
if (output_category_mask) {
|
||||
task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
|
||||
graph.Out(kCategoryMaskTag);
|
||||
}
|
||||
task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
|
||||
graph.Out(kImageTag);
|
||||
graph.In(kImageTag) >> task_subgraph.In(kImageTag);
|
||||
|
@ -86,16 +105,6 @@ ConvertImageSegmenterOptionsToProto(InteractiveSegmenterOptions* options) {
|
|||
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
|
||||
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
|
||||
options_proto->mutable_base_options()->Swap(base_options_proto.get());
|
||||
switch (options->output_type) {
|
||||
case InteractiveSegmenterOptions::OutputType::CATEGORY_MASK:
|
||||
options_proto->mutable_segmenter_options()->set_output_type(
|
||||
SegmenterOptions::CATEGORY_MASK);
|
||||
break;
|
||||
case InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK:
|
||||
options_proto->mutable_segmenter_options()->set_output_type(
|
||||
SegmenterOptions::CONFIDENCE_MASK);
|
||||
break;
|
||||
}
|
||||
return options_proto;
|
||||
}
|
||||
|
||||
|
@ -104,10 +113,10 @@ ConvertImageSegmenterOptionsToProto(InteractiveSegmenterOptions* options) {
|
|||
absl::StatusOr<RenderData> ConvertRoiToRenderData(const RegionOfInterest& roi) {
|
||||
RenderData result;
|
||||
switch (roi.format) {
|
||||
case RegionOfInterest::UNSPECIFIED:
|
||||
case RegionOfInterest::Format::kUnspecified:
|
||||
return absl::InvalidArgumentError(
|
||||
"RegionOfInterest format not specified");
|
||||
case RegionOfInterest::KEYPOINT:
|
||||
case RegionOfInterest::Format::kKeyPoint:
|
||||
RET_CHECK(roi.keypoint.has_value());
|
||||
auto* annotation = result.add_render_annotations();
|
||||
annotation->mutable_color()->set_r(255);
|
||||
|
@ -125,15 +134,29 @@ absl::StatusOr<RenderData> ConvertRoiToRenderData(const RegionOfInterest& roi) {
|
|||
absl::StatusOr<std::unique_ptr<InteractiveSegmenter>>
|
||||
InteractiveSegmenter::Create(
|
||||
std::unique_ptr<InteractiveSegmenterOptions> options) {
|
||||
auto options_proto = ConvertImageSegmenterOptionsToProto(options.get());
|
||||
return core::VisionTaskApiFactory::Create<InteractiveSegmenter,
|
||||
if (!options->output_confidence_masks && !options->output_category_mask) {
|
||||
return absl::InvalidArgumentError(
|
||||
"At least one of `output_confidence_masks` and `output_category_mask` "
|
||||
"must be set.");
|
||||
}
|
||||
std::unique_ptr<ImageSegmenterGraphOptionsProto> options_proto =
|
||||
ConvertImageSegmenterOptionsToProto(options.get());
|
||||
ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<InteractiveSegmenter> segmenter,
|
||||
(core::VisionTaskApiFactory::Create<InteractiveSegmenter,
|
||||
ImageSegmenterGraphOptionsProto>(
|
||||
CreateGraphConfig(std::move(options_proto)),
|
||||
std::move(options->base_options.op_resolver), core::RunningMode::IMAGE,
|
||||
/*packets_callback=*/nullptr);
|
||||
CreateGraphConfig(std::move(options_proto),
|
||||
options->output_confidence_masks,
|
||||
options->output_category_mask),
|
||||
std::move(options->base_options.op_resolver),
|
||||
core::RunningMode::IMAGE,
|
||||
/*packets_callback=*/nullptr)));
|
||||
segmenter->output_category_mask_ = options->output_category_mask;
|
||||
segmenter->output_confidence_masks_ = options->output_confidence_masks;
|
||||
return segmenter;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::vector<Image>> InteractiveSegmenter::Segment(
|
||||
absl::StatusOr<ImageSegmenterResult> InteractiveSegmenter::Segment(
|
||||
mediapipe::Image image, const RegionOfInterest& roi,
|
||||
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
||||
if (image.UsesGpu()) {
|
||||
|
@ -154,7 +177,16 @@ absl::StatusOr<std::vector<Image>> InteractiveSegmenter::Segment(
|
|||
mediapipe::MakePacket<RenderData>(std::move(roi_as_render_data))},
|
||||
{kNormRectStreamName,
|
||||
MakePacket<NormalizedRect>(std::move(norm_rect))}}));
|
||||
return output_packets[kSegmentationStreamName].Get<std::vector<Image>>();
|
||||
std::optional<std::vector<Image>> confidence_masks;
|
||||
if (output_confidence_masks_) {
|
||||
confidence_masks =
|
||||
output_packets[kConfidenceMasksStreamName].Get<std::vector<Image>>();
|
||||
}
|
||||
std::optional<Image> category_mask;
|
||||
if (output_category_mask_) {
|
||||
category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
|
||||
}
|
||||
return {{confidence_masks, category_mask}};
|
||||
}
|
||||
|
||||
} // namespace interactive_segmenter
|
||||
|
|
|
@ -21,12 +21,14 @@ limitations under the License.
|
|||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/keypoint.h"
|
||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
|
@ -39,30 +41,24 @@ struct InteractiveSegmenterOptions {
|
|||
// file with metadata, accelerator options, op resolver, etc.
|
||||
tasks::core::BaseOptions base_options;
|
||||
|
||||
// The output type of segmentation results.
|
||||
enum OutputType {
|
||||
// Gives a single output mask where each pixel represents the class which
|
||||
// the pixel in the original image was predicted to belong to.
|
||||
CATEGORY_MASK = 0,
|
||||
// Gives a list of output masks where, for each mask, each pixel represents
|
||||
// the prediction confidence, usually in the [0, 1] range.
|
||||
CONFIDENCE_MASK = 1,
|
||||
};
|
||||
// Whether to output confidence masks.
|
||||
bool output_confidence_masks = true;
|
||||
|
||||
OutputType output_type = OutputType::CATEGORY_MASK;
|
||||
// Whether to output category mask.
|
||||
bool output_category_mask = false;
|
||||
};
|
||||
|
||||
// The Region-Of-Interest (ROI) to interact with.
|
||||
struct RegionOfInterest {
|
||||
enum Format {
|
||||
UNSPECIFIED = 0, // Format not specified.
|
||||
KEYPOINT = 1, // Using keypoint to represent ROI.
|
||||
enum class Format {
|
||||
kUnspecified = 0, // Format not specified.
|
||||
kKeyPoint = 1, // Using keypoint to represent ROI.
|
||||
};
|
||||
|
||||
// Specifies the format used to specify the region-of-interest. Note that
|
||||
// using `UNSPECIFIED` is invalid and will lead to an `InvalidArgument` status
|
||||
// being returned.
|
||||
Format format = Format::UNSPECIFIED;
|
||||
Format format = Format::kUnspecified;
|
||||
|
||||
// Represents the ROI in keypoint format, this should be non-nullopt if
|
||||
// `format` is `KEYPOINT`.
|
||||
|
@ -84,13 +80,11 @@ struct RegionOfInterest {
|
|||
// - RGB inputs is supported (`channels` is required to be 3).
|
||||
// - if type is kTfLiteFloat32, NormalizationOptions are required to be
|
||||
// attached to the metadata for input normalization.
|
||||
// Output tensors:
|
||||
// (kTfLiteUInt8/kTfLiteFloat32)
|
||||
// - list of segmented masks.
|
||||
// - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1.
|
||||
// - if `output_type` is CONFIDENCE_MASK, float32 Image list of size
|
||||
// `channels`.
|
||||
// - batch is always 1
|
||||
// Output ImageSegmenterResult:
|
||||
// Provides optional confidence masks if `output_confidence_masks` is set
|
||||
// true, and an optional category mask if `output_category_mask` is set
|
||||
// true. At least one of `output_confidence_masks` and `output_category_mask`
|
||||
// must be set to true.
|
||||
class InteractiveSegmenter : tasks::vision::core::BaseVisionTaskApi {
|
||||
public:
|
||||
using BaseVisionTaskApi::BaseVisionTaskApi;
|
||||
|
@ -114,18 +108,17 @@ class InteractiveSegmenter : tasks::vision::core::BaseVisionTaskApi {
|
|||
// setting its 'rotation_degrees' field. Note that specifying a
|
||||
// region-of-interest using the 'region_of_interest' field is NOT supported
|
||||
// and will result in an invalid argument error being returned.
|
||||
//
|
||||
// If the output_type is CATEGORY_MASK, the returned vector of images is
|
||||
// per-category segmented image mask.
|
||||
// If the output_type is CONFIDENCE_MASK, the returned vector of images
|
||||
// contains only one confidence image mask.
|
||||
absl::StatusOr<std::vector<mediapipe::Image>> Segment(
|
||||
absl::StatusOr<image_segmenter::ImageSegmenterResult> Segment(
|
||||
mediapipe::Image image, const RegionOfInterest& roi,
|
||||
std::optional<core::ImageProcessingOptions> image_processing_options =
|
||||
std::nullopt);
|
||||
|
||||
// Shuts down the InteractiveSegmenter when all works are done.
|
||||
absl::Status Close() { return runner_->Close(); }
|
||||
|
||||
private:
|
||||
bool output_confidence_masks_;
|
||||
bool output_category_mask_;
|
||||
};
|
||||
|
||||
} // namespace interactive_segmenter
|
||||
|
|
|
@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mediapipe/calculators/util/flat_color_image_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
|
@ -23,8 +26,9 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
|
||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h"
|
||||
#include "mediapipe/util/color.pb.h"
|
||||
#include "mediapipe/util/label_map.pb.h"
|
||||
#include "mediapipe/util/graph_builder_utils.h"
|
||||
#include "mediapipe/util/render_data.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
@ -42,16 +46,18 @@ using ::mediapipe::api2::Output;
|
|||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
|
||||
constexpr char kSegmentationTag[] = "SEGMENTATION";
|
||||
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
|
||||
constexpr char kImageTag[] = "IMAGE";
|
||||
constexpr char kImageCpuTag[] = "IMAGE_CPU";
|
||||
constexpr char kImageGpuTag[] = "IMAGE_GPU";
|
||||
constexpr char kAlphaTag[] = "ALPHA";
|
||||
constexpr char kAlphaGpuTag[] = "ALPHA_GPU";
|
||||
constexpr char kNormRectTag[] = "NORM_RECT";
|
||||
constexpr char kRoiTag[] = "ROI";
|
||||
constexpr char kVideoTag[] = "VIDEO";
|
||||
constexpr absl::string_view kSegmentationTag{"SEGMENTATION"};
|
||||
constexpr absl::string_view kGroupedSegmentationTag{"GROUPED_SEGMENTATION"};
|
||||
constexpr absl::string_view kConfidenceMaskTag{"CONFIDENCE_MASK"};
|
||||
constexpr absl::string_view kConfidenceMasksTag{"CONFIDENCE_MASKS"};
|
||||
constexpr absl::string_view kCategoryMaskTag{"CATEGORY_MASK"};
|
||||
constexpr absl::string_view kImageTag{"IMAGE"};
|
||||
constexpr absl::string_view kImageCpuTag{"IMAGE_CPU"};
|
||||
constexpr absl::string_view kImageGpuTag{"IMAGE_GPU"};
|
||||
constexpr absl::string_view kAlphaTag{"ALPHA"};
|
||||
constexpr absl::string_view kAlphaGpuTag{"ALPHA_GPU"};
|
||||
constexpr absl::string_view kNormRectTag{"NORM_RECT"};
|
||||
constexpr absl::string_view kRoiTag{"ROI"};
|
||||
|
||||
// Updates the graph to return `roi` stream which has same dimension as
|
||||
// `image`, and rendered with `roi`. If `use_gpu` is true, returned `Source` is
|
||||
|
@ -87,11 +93,10 @@ Source<> RoiToAlpha(Source<Image> image, Source<RenderData> roi, bool use_gpu,
|
|||
} // namespace
|
||||
|
||||
// An "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"
|
||||
// performs semantic segmentation given user's region-of-interest. Two kinds of
|
||||
// outputs are provided: SEGMENTATION and GROUPED_SEGMENTATION. Users can
|
||||
// retrieve segmented mask of only particular category/channel from
|
||||
// SEGMENTATION, and users can also get all segmented masks from
|
||||
// GROUPED_SEGMENTATION.
|
||||
// performs semantic segmentation given the user's region-of-interest. The graph
|
||||
// can output optional confidence masks if CONFIDENCE_MASKS is connected, and an
|
||||
// optional category mask if CATEGORY_MASK is connected. At least one of
|
||||
// CONFIDENCE_MASK, CONFIDENCE_MASKS and CATEGORY_MASK must be connected.
|
||||
// - Accepts CPU input images and outputs segmented masks on CPU.
|
||||
//
|
||||
// Inputs:
|
||||
|
@ -106,11 +111,13 @@ Source<> RoiToAlpha(Source<Image> image, Source<RenderData> roi, bool use_gpu,
|
|||
// @Optional: rect covering the whole image is used if not specified.
|
||||
//
|
||||
// Outputs:
|
||||
// SEGMENTATION - mediapipe::Image @Multiple
|
||||
// Segmented masks for individual category. Segmented mask of single
|
||||
// CONFIDENCE_MASK - mediapipe::Image @Multiple
|
||||
// Confidence masks for individual category. Confidence mask of single
|
||||
// category can be accessed by index based output stream.
|
||||
// GROUPED_SEGMENTATION - std::vector<mediapipe::Image>
|
||||
// The output segmented masks grouped in a vector.
|
||||
// CONFIDENCE_MASKS - std::vector<mediapipe::Image> @Optional
|
||||
// The output confidence masks grouped in a vector.
|
||||
// CATEGORY_MASK - mediapipe::Image @Optional
|
||||
// Optional Category mask.
|
||||
// IMAGE - mediapipe::Image
|
||||
// The image that image segmenter runs on.
|
||||
//
|
||||
|
@ -129,9 +136,6 @@ Source<> RoiToAlpha(Source<Image> image, Source<RenderData> roi, bool use_gpu,
|
|||
// file_name: "/path/to/model.tflite"
|
||||
// }
|
||||
// }
|
||||
// segmenter_options {
|
||||
// output_type: CONFIDENCE_MASK
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
@ -176,10 +180,26 @@ class InteractiveSegmenterGraph : public core::ModelTaskGraph {
|
|||
image_with_set_alpha >> image_segmenter.In(kImageTag);
|
||||
norm_rect >> image_segmenter.In(kNormRectTag);
|
||||
|
||||
// TODO: remove deprecated output type support.
|
||||
if (task_options.segmenter_options().has_output_type()) {
|
||||
image_segmenter.Out(kSegmentationTag) >>
|
||||
graph[Output<Image>(kSegmentationTag)];
|
||||
image_segmenter.Out(kGroupedSegmentationTag) >>
|
||||
graph[Output<std::vector<Image>>(kGroupedSegmentationTag)];
|
||||
} else {
|
||||
if (HasOutput(sc->OriginalNode(), kConfidenceMaskTag)) {
|
||||
image_segmenter.Out(kConfidenceMaskTag) >>
|
||||
graph[Output<Image>(kConfidenceMaskTag)];
|
||||
}
|
||||
if (HasOutput(sc->OriginalNode(), kConfidenceMasksTag)) {
|
||||
image_segmenter.Out(kConfidenceMasksTag) >>
|
||||
graph[Output<Image>(kConfidenceMasksTag)];
|
||||
}
|
||||
if (HasOutput(sc->OriginalNode(), kCategoryMaskTag)) {
|
||||
image_segmenter.Out(kCategoryMaskTag) >>
|
||||
graph[Output<Image>(kCategoryMaskTag)];
|
||||
}
|
||||
}
|
||||
image_segmenter.Out(kImageTag) >> graph[Output<Image>(kImageTag)];
|
||||
|
||||
return graph.GetConfig();
|
||||
|
|
|
@ -17,8 +17,11 @@ limitations under the License.
|
|||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/image_frame.h"
|
||||
|
@ -39,6 +42,7 @@ limitations under the License.
|
|||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
||||
#include "tensorflow/lite/mutable_op_resolver.h"
|
||||
#include "testing/base/public/gmock.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
|
@ -53,13 +57,16 @@ using ::mediapipe::tasks::components::containers::RectF;
|
|||
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
|
||||
using ::testing::HasSubstr;
|
||||
using ::testing::Optional;
|
||||
using ::testing::SizeIs;
|
||||
using ::testing::status::StatusIs;
|
||||
|
||||
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
|
||||
constexpr char kPtmModel[] = "ptm_512_hdt_ptm_woid.tflite";
|
||||
constexpr char kCatsAndDogsJpg[] = "cats_and_dogs.jpg";
|
||||
constexpr absl::string_view kTestDataDirectory{
|
||||
"/mediapipe/tasks/testdata/vision/"};
|
||||
constexpr absl::string_view kPtmModel{"ptm_512_hdt_ptm_woid.tflite"};
|
||||
constexpr absl::string_view kCatsAndDogsJpg{"cats_and_dogs.jpg"};
|
||||
// Golden mask for the dogs in cats_and_dogs.jpg.
|
||||
constexpr char kCatsAndDogsMaskDog1[] = "cats_and_dogs_mask_dog1.png";
|
||||
constexpr char kCatsAndDogsMaskDog2[] = "cats_and_dogs_mask_dog2.png";
|
||||
constexpr absl::string_view kCatsAndDogsMaskDog1{"cats_and_dogs_mask_dog1.png"};
|
||||
constexpr absl::string_view kCatsAndDogsMaskDog2{"cats_and_dogs_mask_dog2.png"};
|
||||
|
||||
constexpr float kGoldenMaskSimilarity = 0.97;
|
||||
|
||||
|
@ -135,35 +142,45 @@ TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
|
|||
JoinPath("./", kTestDataDirectory, kPtmModel);
|
||||
options->base_options.op_resolver =
|
||||
absl::make_unique<DeepLabOpResolverMissingOps>();
|
||||
auto segmenter_or = InteractiveSegmenter::Create(std::move(options));
|
||||
auto segmenter = InteractiveSegmenter::Create(std::move(options));
|
||||
// TODO: Make MediaPipe InferenceCalculator report the detailed
|
||||
// interpreter errors (e.g., "Encountered unresolved custom op").
|
||||
EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInternal);
|
||||
EXPECT_EQ(segmenter.status().code(), absl::StatusCode::kInternal);
|
||||
EXPECT_THAT(
|
||||
segmenter_or.status().message(),
|
||||
segmenter.status().message(),
|
||||
testing::HasSubstr("interpreter_builder(&interpreter) == kTfLiteOk"));
|
||||
}
|
||||
|
||||
TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
|
||||
absl::StatusOr<std::unique_ptr<InteractiveSegmenter>> segmenter_or =
|
||||
absl::StatusOr<std::unique_ptr<InteractiveSegmenter>> segmenter =
|
||||
InteractiveSegmenter::Create(
|
||||
std::make_unique<InteractiveSegmenterOptions>());
|
||||
|
||||
EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInvalidArgument);
|
||||
EXPECT_EQ(segmenter.status().code(), absl::StatusCode::kInvalidArgument);
|
||||
EXPECT_THAT(
|
||||
segmenter_or.status().message(),
|
||||
segmenter.status().message(),
|
||||
HasSubstr("ExternalFile must specify at least one of 'file_content', "
|
||||
"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
|
||||
EXPECT_THAT(segmenter_or.status().GetPayload(kMediaPipeTasksPayload),
|
||||
EXPECT_THAT(segmenter.status().GetPayload(kMediaPipeTasksPayload),
|
||||
Optional(absl::Cord(absl::StrCat(
|
||||
MediaPipeTasksStatus::kRunnerInitializationError))));
|
||||
}
|
||||
|
||||
TEST_F(CreateFromOptionsTest, FailsWithNeitherOutputSet) {
|
||||
auto options = std::make_unique<InteractiveSegmenterOptions>();
|
||||
options->output_category_mask = false;
|
||||
options->output_confidence_masks = false;
|
||||
|
||||
EXPECT_THAT(InteractiveSegmenter::Create(std::move(options)),
|
||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
||||
HasSubstr("At least one of")));
|
||||
}
|
||||
|
||||
struct InteractiveSegmenterTestParams {
|
||||
std::string test_name;
|
||||
RegionOfInterest::Format format;
|
||||
NormalizedKeypoint roi;
|
||||
std::string golden_mask_file;
|
||||
absl::string_view golden_mask_file;
|
||||
float similarity_threshold;
|
||||
};
|
||||
|
||||
|
@ -181,16 +198,18 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
|
|||
auto options = std::make_unique<InteractiveSegmenterOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kPtmModel);
|
||||
options->output_type = InteractiveSegmenterOptions::OutputType::CATEGORY_MASK;
|
||||
options->output_confidence_masks = false;
|
||||
options->output_category_mask = true;
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
|
||||
InteractiveSegmenter::Create(std::move(options)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto category_masks,
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto result,
|
||||
segmenter->Segment(image, interaction_roi));
|
||||
EXPECT_EQ(category_masks.size(), 1);
|
||||
EXPECT_TRUE(result.category_mask.has_value());
|
||||
EXPECT_FALSE(result.confidence_masks.has_value());
|
||||
|
||||
cv::Mat actual_mask = mediapipe::formats::MatView(
|
||||
category_masks[0].GetImageFrameSharedPtr().get());
|
||||
result.category_mask->GetImageFrameSharedPtr().get());
|
||||
|
||||
cv::Mat expected_mask =
|
||||
cv::imread(JoinPath("./", kTestDataDirectory, params.golden_mask_file),
|
||||
|
@ -211,14 +230,13 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
|
|||
auto options = std::make_unique<InteractiveSegmenterOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kPtmModel);
|
||||
options->output_type =
|
||||
InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK;
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
|
||||
InteractiveSegmenter::Create(std::move(options)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks,
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto result,
|
||||
segmenter->Segment(image, interaction_roi));
|
||||
EXPECT_EQ(confidence_masks.size(), 2);
|
||||
EXPECT_FALSE(result.category_mask.has_value());
|
||||
EXPECT_THAT(result.confidence_masks, Optional(SizeIs(2)));
|
||||
|
||||
cv::Mat expected_mask =
|
||||
cv::imread(JoinPath("./", kTestDataDirectory, params.golden_mask_file),
|
||||
|
@ -227,7 +245,7 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
|
|||
expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);
|
||||
|
||||
cv::Mat actual_mask = mediapipe::formats::MatView(
|
||||
confidence_masks[1].GetImageFrameSharedPtr().get());
|
||||
result.confidence_masks->at(1).GetImageFrameSharedPtr().get());
|
||||
EXPECT_THAT(actual_mask, SimilarToFloatMask(expected_mask_float,
|
||||
params.similarity_threshold));
|
||||
}
|
||||
|
@ -235,9 +253,9 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
|
|||
INSTANTIATE_TEST_SUITE_P(
|
||||
SucceedSegmentationWithRoiTest, SucceedSegmentationWithRoi,
|
||||
::testing::ValuesIn<InteractiveSegmenterTestParams>(
|
||||
{{"PointToDog1", RegionOfInterest::KEYPOINT,
|
||||
{{"PointToDog1", RegionOfInterest::Format::kKeyPoint,
|
||||
NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsMaskDog1, 0.84f},
|
||||
{"PointToDog2", RegionOfInterest::KEYPOINT,
|
||||
{"PointToDog2", RegionOfInterest::Format::kKeyPoint,
|
||||
NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsMaskDog2,
|
||||
kGoldenMaskSimilarity}}),
|
||||
[](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>&
|
||||
|
@ -252,22 +270,21 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) {
|
|||
Image image,
|
||||
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
|
||||
RegionOfInterest interaction_roi;
|
||||
interaction_roi.format = RegionOfInterest::KEYPOINT;
|
||||
interaction_roi.format = RegionOfInterest::Format::kKeyPoint;
|
||||
interaction_roi.keypoint = NormalizedKeypoint{0.66, 0.66};
|
||||
auto options = std::make_unique<InteractiveSegmenterOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kPtmModel);
|
||||
options->output_type =
|
||||
InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK;
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
|
||||
InteractiveSegmenter::Create(std::move(options)));
|
||||
ImageProcessingOptions image_processing_options;
|
||||
image_processing_options.rotation_degrees = -90;
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto confidence_masks,
|
||||
auto result,
|
||||
segmenter->Segment(image, interaction_roi, image_processing_options));
|
||||
EXPECT_EQ(confidence_masks.size(), 2);
|
||||
EXPECT_FALSE(result.category_mask.has_value());
|
||||
EXPECT_EQ(result.confidence_masks->size(), 2);
|
||||
}
|
||||
|
||||
TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
|
||||
|
@ -275,13 +292,11 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
|
|||
Image image,
|
||||
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
|
||||
RegionOfInterest interaction_roi;
|
||||
interaction_roi.format = RegionOfInterest::KEYPOINT;
|
||||
interaction_roi.format = RegionOfInterest::Format::kKeyPoint;
|
||||
interaction_roi.keypoint = NormalizedKeypoint{0.66, 0.66};
|
||||
auto options = std::make_unique<InteractiveSegmenterOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kPtmModel);
|
||||
options->output_type =
|
||||
InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK;
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
|
||||
InteractiveSegmenter::Create(std::move(options)));
|
||||
|
|
Loading…
Reference in New Issue
Block a user