API for c++ ImageSegmenter to get labels

PiperOrigin-RevId: 516714139
This commit is contained in:
MediaPipe Team 2023-03-14 21:10:43 -07:00 committed by Copybara-Service
parent cafff14135
commit f517eddce1
4 changed files with 95 additions and 7 deletions

View File

@ -25,6 +25,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":image_segmenter_graph",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",
@ -34,10 +35,13 @@ cc_library(
"//mediapipe/tasks/cc/vision/core:image_processing_options",
"//mediapipe/tasks/cc/vision/core:running_mode",
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_cc_proto",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto",
"//mediapipe/util:label_map_cc_proto",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
],
)

View File

@ -15,15 +15,21 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h"
#include <optional>
#include "absl/strings/str_format.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h"
#include "mediapipe/util/label_map.pb.h"
namespace mediapipe {
namespace tasks {
@ -112,6 +118,39 @@ ConvertImageSegmenterOptionsToProto(ImageSegmenterOptions* options) {
return options_proto;
}
absl::StatusOr<std::vector<std::string>> GetLabelsFromGraphConfig(
const CalculatorGraphConfig& graph_config) {
bool found_tensor_to_segmentation_calculator = false;
std::vector<std::string> labels;
for (const auto& node : graph_config.node()) {
if (node.calculator() ==
"mediapipe.tasks.TensorsToSegmentationCalculator") {
if (!found_tensor_to_segmentation_calculator) {
found_tensor_to_segmentation_calculator = true;
} else {
return absl::Status(CreateStatusWithPayload(
absl::StatusCode::kFailedPrecondition,
"The graph has more than one "
"mediapipe.tasks.TensorsToSegmentationCalculator."));
}
TensorsToSegmentationCalculatorOptions options =
node.options().GetExtension(
TensorsToSegmentationCalculatorOptions::ext);
if (!options.label_items().empty()) {
for (int i = 0; i < options.label_items_size(); ++i) {
if (!options.label_items().contains(i)) {
return absl::Status(CreateStatusWithPayload(
absl::StatusCode::kFailedPrecondition,
absl::StrFormat("The lablemap have no expected key: %d.", i)));
}
labels.push_back(options.label_items().at(i).name());
}
}
}
}
return labels;
}
} // namespace
absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
@ -140,13 +179,22 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
kMicroSecondsPerMilliSecond);
};
}
return core::VisionTaskApiFactory::Create<ImageSegmenter,
auto image_segmenter =
core::VisionTaskApiFactory::Create<ImageSegmenter,
ImageSegmenterGraphOptionsProto>(
CreateGraphConfig(
std::move(options_proto),
options->running_mode == core::RunningMode::LIVE_STREAM),
std::move(options->base_options.op_resolver), options->running_mode,
std::move(packets_callback));
if (!image_segmenter.ok()) {
return image_segmenter.status();
}
ASSIGN_OR_RETURN(
(*image_segmenter)->labels_,
GetLabelsFromGraphConfig((*image_segmenter)->runner_->GetGraphConfig()));
return image_segmenter;
}
absl::StatusOr<std::vector<Image>> ImageSegmenter::Segment(

View File

@ -189,6 +189,18 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
// Shuts down the ImageSegmenter when all works are done.
absl::Status Close() { return runner_->Close(); }
// Get the category label list of the ImageSegmenter can recognize. For
// CATEGORY_MASK type, the index in the category mask corresponds to the
// category in the label list. For CONFIDENCE_MASK type, the output mask list
// at index corresponds to the category in the label list.
//
// If there is no labelmap provided in the model file, empty label list is
// returned.
std::vector<std::string> GetLabels() { return labels_; }
private:
std::vector<std::string> labels_;
};
} // namespace image_segmenter

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h"
#include <array>
#include <cstdint>
#include <memory>
@ -71,6 +72,13 @@ constexpr float kGoldenMaskSimilarity = 0.98;
// 20 means class index 2, etc.
constexpr int kGoldenMaskMagnificationFactor = 10;
constexpr std::array<absl::string_view, 21> kDeeplabLabelNames = {
"background", "aeroplane", "bicycle", "bird", "boat",
"bottle", "bus", "car", "cat", "chair",
"cow", "dining table", "dog", "horse", "motorbike",
"person", "potted plant", "sheep", "sofa", "train",
"tv"};
// Intentionally converting output into CV_8UC1 and then again into CV_32FC1
// as expected outputs are stored in CV_8UC1, so this conversion allows to do
// fair comparison.
@ -244,6 +252,22 @@ TEST_F(CreateFromOptionsTest, FailsWithInputChannelOneModel) {
"channels = 3 or 4."));
}
TEST(GetLabelsTest, SucceedsWithLabelsInModel) {
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
const auto& labels = segmenter->GetLabels();
ASSERT_FALSE(labels.empty());
ASSERT_EQ(labels.size(), kDeeplabLabelNames.size());
for (int i = 0; i < labels.size(); ++i) {
EXPECT_EQ(labels[i], kDeeplabLabelNames[i]);
}
}
class ImageModeTest : public tflite_shims::testing::Test {};
TEST_F(ImageModeTest, SucceedsWithCategoryMask) {