API for c++ ImageSegmenter to get labels
PiperOrigin-RevId: 516714139
This commit is contained in:
parent
cafff14135
commit
f517eddce1
|
@ -25,6 +25,7 @@ cc_library(
|
|||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":image_segmenter_graph",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/formats:image",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
|
@ -34,10 +35,13 @@ cc_library(
|
|||
"//mediapipe/tasks/cc/vision/core:image_processing_options",
|
||||
"//mediapipe/tasks/cc/vision/core:running_mode",
|
||||
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto",
|
||||
"//mediapipe/util:label_map_cc_proto",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -15,15 +15,21 @@ limitations under the License.
|
|||
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h"
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/utils.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h"
|
||||
#include "mediapipe/util/label_map.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
|
@ -112,6 +118,39 @@ ConvertImageSegmenterOptionsToProto(ImageSegmenterOptions* options) {
|
|||
return options_proto;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::vector<std::string>> GetLabelsFromGraphConfig(
|
||||
const CalculatorGraphConfig& graph_config) {
|
||||
bool found_tensor_to_segmentation_calculator = false;
|
||||
std::vector<std::string> labels;
|
||||
for (const auto& node : graph_config.node()) {
|
||||
if (node.calculator() ==
|
||||
"mediapipe.tasks.TensorsToSegmentationCalculator") {
|
||||
if (!found_tensor_to_segmentation_calculator) {
|
||||
found_tensor_to_segmentation_calculator = true;
|
||||
} else {
|
||||
return absl::Status(CreateStatusWithPayload(
|
||||
absl::StatusCode::kFailedPrecondition,
|
||||
"The graph has more than one "
|
||||
"mediapipe.tasks.TensorsToSegmentationCalculator."));
|
||||
}
|
||||
TensorsToSegmentationCalculatorOptions options =
|
||||
node.options().GetExtension(
|
||||
TensorsToSegmentationCalculatorOptions::ext);
|
||||
if (!options.label_items().empty()) {
|
||||
for (int i = 0; i < options.label_items_size(); ++i) {
|
||||
if (!options.label_items().contains(i)) {
|
||||
return absl::Status(CreateStatusWithPayload(
|
||||
absl::StatusCode::kFailedPrecondition,
|
||||
absl::StrFormat("The lablemap have no expected key: %d.", i)));
|
||||
}
|
||||
labels.push_back(options.label_items().at(i).name());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return labels;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
|
||||
|
@ -140,13 +179,22 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
|
|||
kMicroSecondsPerMilliSecond);
|
||||
};
|
||||
}
|
||||
return core::VisionTaskApiFactory::Create<ImageSegmenter,
|
||||
ImageSegmenterGraphOptionsProto>(
|
||||
CreateGraphConfig(
|
||||
std::move(options_proto),
|
||||
options->running_mode == core::RunningMode::LIVE_STREAM),
|
||||
std::move(options->base_options.op_resolver), options->running_mode,
|
||||
std::move(packets_callback));
|
||||
|
||||
auto image_segmenter =
|
||||
core::VisionTaskApiFactory::Create<ImageSegmenter,
|
||||
ImageSegmenterGraphOptionsProto>(
|
||||
CreateGraphConfig(
|
||||
std::move(options_proto),
|
||||
options->running_mode == core::RunningMode::LIVE_STREAM),
|
||||
std::move(options->base_options.op_resolver), options->running_mode,
|
||||
std::move(packets_callback));
|
||||
if (!image_segmenter.ok()) {
|
||||
return image_segmenter.status();
|
||||
}
|
||||
ASSIGN_OR_RETURN(
|
||||
(*image_segmenter)->labels_,
|
||||
GetLabelsFromGraphConfig((*image_segmenter)->runner_->GetGraphConfig()));
|
||||
return image_segmenter;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::vector<Image>> ImageSegmenter::Segment(
|
||||
|
|
|
@ -189,6 +189,18 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
|
|||
|
||||
// Shuts down the ImageSegmenter when all works are done.
|
||||
absl::Status Close() { return runner_->Close(); }
|
||||
|
||||
// Get the category label list of the ImageSegmenter can recognize. For
|
||||
// CATEGORY_MASK type, the index in the category mask corresponds to the
|
||||
// category in the label list. For CONFIDENCE_MASK type, the output mask list
|
||||
// at index corresponds to the category in the label list.
|
||||
//
|
||||
// If there is no labelmap provided in the model file, empty label list is
|
||||
// returned.
|
||||
std::vector<std::string> GetLabels() { return labels_; }
|
||||
|
||||
private:
|
||||
std::vector<std::string> labels_;
|
||||
};
|
||||
|
||||
} // namespace image_segmenter
|
||||
|
|
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h"
|
||||
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
|
||||
|
@ -71,6 +72,13 @@ constexpr float kGoldenMaskSimilarity = 0.98;
|
|||
// 20 means class index 2, etc.
|
||||
constexpr int kGoldenMaskMagnificationFactor = 10;
|
||||
|
||||
constexpr std::array<absl::string_view, 21> kDeeplabLabelNames = {
|
||||
"background", "aeroplane", "bicycle", "bird", "boat",
|
||||
"bottle", "bus", "car", "cat", "chair",
|
||||
"cow", "dining table", "dog", "horse", "motorbike",
|
||||
"person", "potted plant", "sheep", "sofa", "train",
|
||||
"tv"};
|
||||
|
||||
// Intentionally converting output into CV_8UC1 and then again into CV_32FC1
|
||||
// as expected outputs are stored in CV_8UC1, so this conversion allows to do
|
||||
// fair comparison.
|
||||
|
@ -244,6 +252,22 @@ TEST_F(CreateFromOptionsTest, FailsWithInputChannelOneModel) {
|
|||
"channels = 3 or 4."));
|
||||
}
|
||||
|
||||
TEST(GetLabelsTest, SucceedsWithLabelsInModel) {
|
||||
auto options = std::make_unique<ImageSegmenterOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
|
||||
options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
||||
ImageSegmenter::Create(std::move(options)));
|
||||
const auto& labels = segmenter->GetLabels();
|
||||
ASSERT_FALSE(labels.empty());
|
||||
ASSERT_EQ(labels.size(), kDeeplabLabelNames.size());
|
||||
for (int i = 0; i < labels.size(); ++i) {
|
||||
EXPECT_EQ(labels[i], kDeeplabLabelNames[i]);
|
||||
}
|
||||
}
|
||||
|
||||
class ImageModeTest : public tflite_shims::testing::Test {};
|
||||
|
||||
TEST_F(ImageModeTest, SucceedsWithCategoryMask) {
|
||||
|
|
Loading…
Reference in New Issue
Block a user