Add support for rotation in ImageEmbedder & ImageSegmenter C++ APIs

PiperOrigin-RevId: 483416498
This commit is contained in:
MediaPipe Team 2022-10-24 10:12:41 -07:00 committed by Copybara-Service
parent 0fd69e8d83
commit 2f2baeff68
11 changed files with 301 additions and 72 deletions

View File

@ -58,6 +58,7 @@ cc_library(
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
"//mediapipe/tasks/cc/vision/core:base_vision_task_api",
"//mediapipe/tasks/cc/vision/core:image_processing_options",
"//mediapipe/tasks/cc/vision/core:running_mode",
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto",

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.pb.h"
@ -58,16 +59,6 @@ using ::mediapipe::tasks::core::PacketMap;
using ::mediapipe::tasks::vision::image_embedder::proto::
ImageEmbedderGraphOptions;
// Builds a NormalizedRect covering the entire image.
NormalizedRect BuildFullImageNormRect() {
NormalizedRect norm_rect;
norm_rect.set_x_center(0.5);
norm_rect.set_y_center(0.5);
norm_rect.set_width(1);
norm_rect.set_height(1);
return norm_rect;
}
// Creates a MediaPipe graph config that contains a single node of type
// "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph". If the task is
// running in the live stream mode, a "FlowLimiterCalculator" will be added to
@ -148,15 +139,16 @@ absl::StatusOr<std::unique_ptr<ImageEmbedder>> ImageEmbedder::Create(
}
absl::StatusOr<EmbeddingResult> ImageEmbedder::Embed(
Image image, std::optional<NormalizedRect> roi) {
Image image,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"GPU input images are currently not supported.",
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
}
NormalizedRect norm_rect =
roi.has_value() ? roi.value() : BuildFullImageNormRect();
ASSIGN_OR_RETURN(NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options));
ASSIGN_OR_RETURN(
auto output_packets,
ProcessImageData(
@ -167,15 +159,16 @@ absl::StatusOr<EmbeddingResult> ImageEmbedder::Embed(
}
absl::StatusOr<EmbeddingResult> ImageEmbedder::EmbedForVideo(
Image image, int64 timestamp_ms, std::optional<NormalizedRect> roi) {
Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"GPU input images are currently not supported.",
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
}
NormalizedRect norm_rect =
roi.has_value() ? roi.value() : BuildFullImageNormRect();
ASSIGN_OR_RETURN(NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options));
ASSIGN_OR_RETURN(
auto output_packets,
ProcessVideoData(
@ -188,16 +181,17 @@ absl::StatusOr<EmbeddingResult> ImageEmbedder::EmbedForVideo(
return output_packets[kEmbeddingResultStreamName].Get<EmbeddingResult>();
}
absl::Status ImageEmbedder::EmbedAsync(Image image, int64 timestamp_ms,
std::optional<NormalizedRect> roi) {
absl::Status ImageEmbedder::EmbedAsync(
Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"GPU input images are currently not supported.",
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
}
NormalizedRect norm_rect =
roi.has_value() ? roi.value() : BuildFullImageNormRect();
ASSIGN_OR_RETURN(NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options));
return SendLiveStreamData(
{{kImageInStreamName,
MakePacket<Image>(std::move(image))

View File

@ -21,11 +21,11 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/embedder_options.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
namespace mediapipe {
@ -88,9 +88,17 @@ class ImageEmbedder : core::BaseVisionTaskApi {
static absl::StatusOr<std::unique_ptr<ImageEmbedder>> Create(
std::unique_ptr<ImageEmbedderOptions> options);
// Performs embedding extraction on the provided single image. Extraction
// is performed on the region of interest specified by the `roi` argument if
// provided, or on the entire image otherwise.
// Performs embedding extraction on the provided single image.
//
// The optional 'image_processing_options' parameter can be used to specify:
// - the rotation to apply to the image before performing embedding
// extraction, by setting its 'rotation_degrees' field.
// and/or
// - the region-of-interest on which to perform embedding extraction, by
// setting its 'region_of_interest' field. If not specified, the full image
// is used.
// If both are specified, the crop around the region-of-interest is extracted
// first, then the specified rotation is applied to the crop.
//
// Only use this method when the ImageEmbedder is created with the image
// running mode.
@ -98,11 +106,20 @@ class ImageEmbedder : core::BaseVisionTaskApi {
// The image can be of any size with format RGB or RGBA.
absl::StatusOr<components::containers::proto::EmbeddingResult> Embed(
mediapipe::Image image,
std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);
// Performs embedding extraction on the provided video frame. Extraction
// is performed on the region of interested specified by the `roi` argument if
// provided, or on the entire image otherwise.
// Performs embedding extraction on the provided video frame.
//
// The optional 'image_processing_options' parameter can be used to specify:
// - the rotation to apply to the image before performing embedding
// extraction, by setting its 'rotation_degrees' field.
// and/or
// - the region-of-interest on which to perform embedding extraction, by
// setting its 'region_of_interest' field. If not specified, the full image
// is used.
// If both are specified, the crop around the region-of-interest is extracted
// first, then the specified rotation is applied to the crop.
//
// Only use this method when the ImageEmbedder is created with the video
// running mode.
@ -112,12 +129,21 @@ class ImageEmbedder : core::BaseVisionTaskApi {
// must be monotonically increasing.
absl::StatusOr<components::containers::proto::EmbeddingResult> EmbedForVideo(
mediapipe::Image image, int64 timestamp_ms,
std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);
// Sends live image data to embedder, and the results will be available via
// the "result_callback" provided in the ImageEmbedderOptions. Embedding
// extraction is performed on the region of interested specified by the `roi`
// argument if provided, or on the entire image otherwise.
// the "result_callback" provided in the ImageEmbedderOptions.
//
// The optional 'image_processing_options' parameter can be used to specify:
// - the rotation to apply to the image before performing embedding
// extraction, by setting its 'rotation_degrees' field.
// and/or
// - the region-of-interest on which to perform embedding extraction, by
// setting its 'region_of_interest' field. If not specified, the full image
// is used.
// If both are specified, the crop around the region-of-interest is extracted
// first, then the specified rotation is applied to the crop.
//
// Only use this method when the ImageEmbedder is created with the live
// stream running mode.
@ -135,9 +161,9 @@ class ImageEmbedder : core::BaseVisionTaskApi {
// longer be valid when the callback returns. To access the image data
// outside of the callback, callers need to make a copy of the image.
// - The input timestamp in milliseconds.
absl::Status EmbedAsync(
mediapipe::Image image, int64 timestamp_ms,
std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
absl::Status EmbedAsync(mediapipe::Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions>
image_processing_options = std::nullopt);
// Shuts down the ImageEmbedder when all works are done.
absl::Status Close() { return runner_->Close(); }

View File

@ -23,7 +23,6 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
@ -42,7 +41,9 @@ namespace image_embedder {
namespace {
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr;
using ::testing::Optional;
@ -326,16 +327,14 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
MP_ASSERT_OK_AND_ASSIGN(
Image crop, DecodeImageFromFile(
JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
// Bounding box in "burger.jpg" corresponding to "burger_crop.jpg".
NormalizedRect roi;
roi.set_x_center(200.0 / 480);
roi.set_y_center(0.5);
roi.set_width(400.0 / 480);
roi.set_height(1.0f);
// Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg".
Rect roi{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1};
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
// Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
image_embedder->Embed(image, roi));
MP_ASSERT_OK_AND_ASSIGN(
const EmbeddingResult& image_result,
image_embedder->Embed(image, image_processing_options));
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
image_embedder->Embed(crop));
@ -351,6 +350,77 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
}
TEST_F(ImageModeTest, SucceedsWithRotation) {
auto options = std::make_unique<ImageEmbedderOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
ImageEmbedder::Create(std::move(options)));
// Load images: one is a rotated version of the other.
MP_ASSERT_OK_AND_ASSIGN(
Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg")));
MP_ASSERT_OK_AND_ASSIGN(Image rotated,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
"burger_rotated.jpg")));
ImageProcessingOptions image_processing_options;
image_processing_options.rotation_degrees = -90;
// Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
image_embedder->Embed(image));
MP_ASSERT_OK_AND_ASSIGN(
const EmbeddingResult& rotated_result,
image_embedder->Embed(rotated, image_processing_options));
// Check results.
CheckMobileNetV3Result(image_result, false);
CheckMobileNetV3Result(rotated_result, false);
// CheckCosineSimilarity.
MP_ASSERT_OK_AND_ASSIGN(
double similarity,
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0),
rotated_result.embeddings(0).entries(0)));
double expected_similarity = 0.572265;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
}
TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
auto options = std::make_unique<ImageEmbedderOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
ImageEmbedder::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(
Image crop, DecodeImageFromFile(
JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
MP_ASSERT_OK_AND_ASSIGN(Image rotated,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
"burger_rotated.jpg")));
// Region-of-interest corresponding to burger_crop.jpg.
Rect roi{/*left=*/0, /*top=*/0, /*right=*/1, /*bottom=*/0.8333333};
ImageProcessingOptions image_processing_options{roi,
/*rotation_degrees=*/-90};
// Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
image_embedder->Embed(crop));
MP_ASSERT_OK_AND_ASSIGN(
const EmbeddingResult& rotated_result,
image_embedder->Embed(rotated, image_processing_options));
// Check results.
CheckMobileNetV3Result(crop_result, false);
CheckMobileNetV3Result(rotated_result, false);
// CheckCosineSimilarity.
MP_ASSERT_OK_AND_ASSIGN(
double similarity,
ImageEmbedder::CosineSimilarity(crop_result.embeddings(0).entries(0),
rotated_result.embeddings(0).entries(0)));
double expected_similarity = 0.62838;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
}
class VideoModeTest : public tflite_shims::testing::Test {};
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {

View File

@ -24,10 +24,12 @@ cc_library(
":image_segmenter_graph",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto",
"//mediapipe/tasks/cc/core:base_options",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/vision/core:base_vision_task_api",
"//mediapipe/tasks/cc/vision/core:image_processing_options",
"//mediapipe/tasks/cc/vision/core:running_mode",
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto",
@ -48,6 +50,7 @@ cc_library(
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components:image_preprocessing",

View File

@ -17,8 +17,10 @@ limitations under the License.
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
@ -32,6 +34,8 @@ constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out";
constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectStreamName[] = "norm_rect_in";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kSubgraphTypeName[] =
"mediapipe.tasks.vision.ImageSegmenterGraph";
constexpr int kMicroSecondsPerMilliSecond = 1000;
@ -51,15 +55,18 @@ CalculatorGraphConfig CreateGraphConfig(
auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
task_subgraph.GetOptions<ImageSegmenterOptionsProto>().Swap(options.get());
graph.In(kImageTag).SetName(kImageInStreamName);
graph.In(kNormRectTag).SetName(kNormRectStreamName);
task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >>
graph.Out(kGroupedSegmentationTag);
task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
graph.Out(kImageTag);
if (enable_flow_limiting) {
return tasks::core::AddFlowLimiterCalculator(
graph, task_subgraph, {kImageTag}, kGroupedSegmentationTag);
return tasks::core::AddFlowLimiterCalculator(graph, task_subgraph,
{kImageTag, kNormRectTag},
kGroupedSegmentationTag);
}
graph.In(kImageTag) >> task_subgraph.In(kImageTag);
graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag);
return graph.GetConfig();
}
@ -139,47 +146,68 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
}
absl::StatusOr<std::vector<Image>> ImageSegmenter::Segment(
mediapipe::Image image) {
mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrCat("GPU input images are currently not supported."),
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
}
ASSIGN_OR_RETURN(
NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
ASSIGN_OR_RETURN(
auto output_packets,
ProcessImageData({{kImageInStreamName,
mediapipe::MakePacket<Image>(std::move(image))}}));
ProcessImageData(
{{kImageInStreamName, mediapipe::MakePacket<Image>(std::move(image))},
{kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect))}}));
return output_packets[kSegmentationStreamName].Get<std::vector<Image>>();
}
absl::StatusOr<std::vector<Image>> ImageSegmenter::SegmentForVideo(
mediapipe::Image image, int64 timestamp_ms) {
mediapipe::Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrCat("GPU input images are currently not supported."),
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
}
ASSIGN_OR_RETURN(
NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
ASSIGN_OR_RETURN(
auto output_packets,
ProcessVideoData(
{{kImageInStreamName,
MakePacket<Image>(std::move(image))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
{kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
return output_packets[kSegmentationStreamName].Get<std::vector<Image>>();
}
absl::Status ImageSegmenter::SegmentAsync(Image image, int64 timestamp_ms) {
absl::Status ImageSegmenter::SegmentAsync(
Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrCat("GPU input images are currently not supported."),
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
}
ASSIGN_OR_RETURN(
NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
return SendLiveStreamData(
{{kImageInStreamName,
MakePacket<Image>(std::move(image))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
{kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
}

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
#include "tensorflow/lite/kernels/register.h"
@ -116,14 +117,21 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
// running mode.
//
// The image can be of any size with format RGB or RGBA.
// TODO: Describes how the input image will be preprocessed
// after the yuv support is implemented.
//
// The optional 'image_processing_options' parameter can be used to specify
// the rotation to apply to the image before performing segmentation, by
// setting its 'rotation_degrees' field. Note that specifying a
// region-of-interest using the 'region_of_interest' field is NOT supported
// and will result in an invalid argument error being returned.
//
// If the output_type is CATEGORY_MASK, the returned vector of images is
// per-category segmented image mask.
// If the output_type is CONFIDENCE_MASK, the returned vector of images
// contains only one confidence image mask.
absl::StatusOr<std::vector<mediapipe::Image>> Segment(mediapipe::Image image);
absl::StatusOr<std::vector<mediapipe::Image>> Segment(
mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);
// Performs image segmentation on the provided video frame.
// Only use this method when the ImageSegmenter is created with the video
@ -133,12 +141,20 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
// provide the video frame's timestamp (in milliseconds). The input timestamps
// must be monotonically increasing.
//
// The optional 'image_processing_options' parameter can be used to specify
// the rotation to apply to the image before performing segmentation, by
// setting its 'rotation_degrees' field. Note that specifying a
// region-of-interest using the 'region_of_interest' field is NOT supported
// and will result in an invalid argument error being returned.
//
// If the output_type is CATEGORY_MASK, the returned vector of images is
// per-category segmented image mask.
// If the output_type is CONFIDENCE_MASK, the returned vector of images
// contains only one confidence image mask.
absl::StatusOr<std::vector<mediapipe::Image>> SegmentForVideo(
mediapipe::Image image, int64 timestamp_ms);
mediapipe::Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);
// Sends live image data to perform image segmentation, and the results will
// be available via the "result_callback" provided in the
@ -150,6 +166,12 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
// sent to the image segmenter. The input timestamps must be monotonically
// increasing.
//
// The optional 'image_processing_options' parameter can be used to specify
// the rotation to apply to the image before performing segmentation, by
// setting its 'rotation_degrees' field. Note that specifying a
// region-of-interest using the 'region_of_interest' field is NOT supported
// and will result in an invalid argument error being returned.
//
// The "result_callback" prvoides
// - A vector of segmented image masks.
// If the output_type is CATEGORY_MASK, the returned vector of images is
@ -161,7 +183,9 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
// no longer be valid when the callback returns. To access the image data
// outside of the callback, callers need to make a copy of the image.
// - The input timestamp in milliseconds.
absl::Status SegmentAsync(mediapipe::Image image, int64 timestamp_ms);
absl::Status SegmentAsync(mediapipe::Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions>
image_processing_options = std::nullopt);
// Shuts down the ImageSegmenter when all works are done.
absl::Status Close() { return runner_->Close(); }

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h"
@ -62,6 +63,7 @@ using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
constexpr char kSegmentationTag[] = "SEGMENTATION";
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
@ -159,6 +161,10 @@ absl::StatusOr<const Tensor*> GetOutputTensor(
// Inputs:
// IMAGE - Image
// Image to perform segmentation on.
// NORM_RECT - NormalizedRect @Optional
// Describes image rotation and region of image to perform detection
// on.
// @Optional: rect covering the whole image is used if not specified.
//
// Outputs:
// SEGMENTATION - mediapipe::Image @Multiple
@ -196,10 +202,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
ASSIGN_OR_RETURN(const auto* model_resources,
CreateModelResources<ImageSegmenterOptions>(sc));
Graph graph;
ASSIGN_OR_RETURN(auto output_streams,
BuildSegmentationTask(
sc->Options<ImageSegmenterOptions>(), *model_resources,
graph[Input<Image>(kImageTag)], graph));
ASSIGN_OR_RETURN(
auto output_streams,
BuildSegmentationTask(
sc->Options<ImageSegmenterOptions>(), *model_resources,
graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
auto& merge_images_to_vector =
graph.AddNode("MergeImagesToVectorCalculator");
@ -228,7 +236,7 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
absl::StatusOr<ImageSegmenterOutputs> BuildSegmentationTask(
const ImageSegmenterOptions& task_options,
const core::ModelResources& model_resources, Source<Image> image_in,
Graph& graph) {
Source<NormalizedRect> norm_rect_in, Graph& graph) {
MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));
// Adds preprocessing calculators and connects them to the graph input image
@ -240,6 +248,7 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
&preprocessing
.GetOptions<tasks::components::ImagePreprocessingOptions>()));
image_in >> preprocessing.In(kImageTag);
norm_rect_in >> preprocessing.In(kNormRectTag);
// Adds inference subgraph and connects its input stream to the output
// tensors produced by the ImageToTensorCalculator.

View File

@ -29,8 +29,10 @@ limitations under the License.
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
@ -44,6 +46,8 @@ namespace {
using ::mediapipe::Image;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr;
using ::testing::Optional;
@ -237,7 +241,6 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto results, segmenter->Segment(image));
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
EXPECT_EQ(confidence_masks.size(), 21);
@ -253,6 +256,61 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
}
TEST_F(ImageModeTest, SucceedsWithRotation) {
MP_ASSERT_OK_AND_ASSIGN(
Image image, DecodeImageFromFile(
JoinPath("./", kTestDataDirectory, "cat_rotated.jpg")));
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
ImageProcessingOptions image_processing_options;
image_processing_options.rotation_degrees = -90;
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
EXPECT_EQ(confidence_masks.size(), 21);
cv::Mat expected_mask =
cv::imread(JoinPath("./", kTestDataDirectory, "cat_rotated_mask.jpg"),
cv::IMREAD_GRAYSCALE);
cv::Mat expected_mask_float;
expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);
// Cat category index 8.
cv::Mat cat_mask = mediapipe::formats::MatView(
confidence_masks[8].GetImageFrameSharedPtr().get());
EXPECT_THAT(cat_mask,
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
}
TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
MP_ASSERT_OK_AND_ASSIGN(
Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "cat.jpg")));
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
auto results = segmenter->Segment(image, image_processing_options);
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(results.status().message(),
HasSubstr("This task doesn't support region-of-interest"));
EXPECT_THAT(
results.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
}
TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) {
Image image =
GetSRGBImage(JoinPath("./", kTestDataDirectory, "mozart_square.jpg"));

View File

@ -28,6 +28,8 @@ mediapipe_files(srcs = [
"burger_rotated.jpg",
"cat.jpg",
"cat_mask.jpg",
"cat_rotated.jpg",
"cat_rotated_mask.jpg",
"cats_and_dogs.jpg",
"cats_and_dogs_no_resizing.jpg",
"cats_and_dogs_rotated.jpg",
@ -84,6 +86,8 @@ filegroup(
"burger_rotated.jpg",
"cat.jpg",
"cat_mask.jpg",
"cat_rotated.jpg",
"cat_rotated_mask.jpg",
"cats_and_dogs.jpg",
"cats_and_dogs_no_resizing.jpg",
"cats_and_dogs_rotated.jpg",

View File

@ -76,6 +76,18 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/cat_mask.jpg?generation=1661875677203533"],
)
http_file(
name = "com_google_mediapipe_cat_rotated_jpg",
sha256 = "b78cee5ad14c9f36b1c25d103db371d81ca74d99030063c46a38e80bb8f38649",
urls = ["https://storage.googleapis.com/mediapipe-assets/cat_rotated.jpg?generation=1666304165042123"],
)
http_file(
name = "com_google_mediapipe_cat_rotated_mask_jpg",
sha256 = "f336973e7621d602f2ebc9a6ab1c62d8502272d391713f369d3b99541afda861",
urls = ["https://storage.googleapis.com/mediapipe-assets/cat_rotated_mask.jpg?generation=1666304167148173"],
)
http_file(
name = "com_google_mediapipe_cats_and_dogs_jpg",
sha256 = "a2eaa7ad3a1aae4e623dd362a5f737e8a88d122597ecd1a02b3e1444db56df9c",
@ -162,8 +174,8 @@ def external_files():
http_file(
name = "com_google_mediapipe_expected_left_down_hand_rotated_landmarks_prototxt",
sha256 = "a16d6cb8dd07d60f0678ddeb6a7447b73b9b03d4ddde365c8770b472205bb6cf",
urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_down_hand_rotated_landmarks.prototxt?generation=1666037061297507"],
sha256 = "c4dfdcc2e4cd366eb5f8ad227be94049eb593e3a528564611094687912463687",
urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_down_hand_rotated_landmarks.prototxt?generation=1666304169636598"],
)
http_file(
@ -174,8 +186,8 @@ def external_files():
http_file(
name = "com_google_mediapipe_expected_left_up_hand_rotated_landmarks_prototxt",
sha256 = "a9b9789c274d48a7cb9cc10af7bc644eb2512bb934529790d0a5404726daa86a",
urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_up_hand_rotated_landmarks.prototxt?generation=1666037063443676"],
sha256 = "7fb2d33cf69d2da50952a45bad0c0618f30859e608958fee95948a6e0de63ccb",
urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_up_hand_rotated_landmarks.prototxt?generation=1666304171758037"],
)
http_file(
@ -258,8 +270,8 @@ def external_files():
http_file(
name = "com_google_mediapipe_hand_detector_result_one_hand_rotated_pbtxt",
sha256 = "ff5ca0654028d78a3380df90054273cae79abe1b7369b164063fd1d5758ec370",
urls = ["https://storage.googleapis.com/mediapipe-assets/hand_detector_result_one_hand_rotated.pbtxt?generation=1666037065601724"],
sha256 = "555079c274ea91699757a0b9888c9993a8ab450069103b1bcd4ebb805a8e023c",
urls = ["https://storage.googleapis.com/mediapipe-assets/hand_detector_result_one_hand_rotated.pbtxt?generation=1666304174234283"],
)
http_file(
@ -606,8 +618,8 @@ def external_files():
http_file(
name = "com_google_mediapipe_pointing_up_rotated_landmarks_pbtxt",
sha256 = "ccf67e5867094ffb6c465a4dfbf2ef1eb3f9db2465803fc25a0b84c958e050de",
urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_rotated_landmarks.pbtxt?generation=1666037074376515"],
sha256 = "5ec37218d8b613436f5c10121dc689bf9ee69af0656a6ccf8c2e3e8b652e2ad6",
urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_rotated_landmarks.pbtxt?generation=1666304178388806"],
)
http_file(
@ -798,8 +810,8 @@ def external_files():
http_file(
name = "com_google_mediapipe_thumb_up_rotated_landmarks_pbtxt",
sha256 = "5d0a465959cacbd201ac8dd8fc8a66c5997a172b71809b12d27296db6a28a102",
urls = ["https://storage.googleapis.com/mediapipe-assets/thumb_up_rotated_landmarks.pbtxt?generation=1666037079490527"],
sha256 = "6645bbd98ea7f90b3e1ba297e16ea5280847fc5bf5400726d98c282f6c597257",
urls = ["https://storage.googleapis.com/mediapipe-assets/thumb_up_rotated_landmarks.pbtxt?generation=1666304181397432"],
)
http_file(