API for c++ ImageSegmenter to get labels
PiperOrigin-RevId: 516714139
This commit is contained in:
parent
cafff14135
commit
f517eddce1
|
@ -25,6 +25,7 @@ cc_library(
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
":image_segmenter_graph",
|
":image_segmenter_graph",
|
||||||
|
"//mediapipe/framework:calculator_cc_proto",
|
||||||
"//mediapipe/framework/api2:builder",
|
"//mediapipe/framework/api2:builder",
|
||||||
"//mediapipe/framework/formats:image",
|
"//mediapipe/framework/formats:image",
|
||||||
"//mediapipe/framework/formats:rect_cc_proto",
|
"//mediapipe/framework/formats:rect_cc_proto",
|
||||||
|
@ -34,10 +35,13 @@ cc_library(
|
||||||
"//mediapipe/tasks/cc/vision/core:image_processing_options",
|
"//mediapipe/tasks/cc/vision/core:image_processing_options",
|
||||||
"//mediapipe/tasks/cc/vision/core:running_mode",
|
"//mediapipe/tasks/cc/vision/core:running_mode",
|
||||||
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
|
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
|
||||||
|
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_cc_proto",
|
||||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
|
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto",
|
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto",
|
||||||
|
"//mediapipe/util:label_map_cc_proto",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
|
"@com_google_absl//absl/strings:str_format",
|
||||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -15,15 +15,21 @@ limitations under the License.
|
||||||
|
|
||||||
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h"
|
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h"
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
|
#include "absl/strings/str_format.h"
|
||||||
#include "mediapipe/framework/api2/builder.h"
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
|
#include "mediapipe/framework/calculator.pb.h"
|
||||||
#include "mediapipe/framework/formats/image.h"
|
#include "mediapipe/framework/formats/image.h"
|
||||||
#include "mediapipe/framework/formats/rect.pb.h"
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/utils.h"
|
#include "mediapipe/tasks/cc/core/utils.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
|
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h"
|
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h"
|
||||||
|
#include "mediapipe/util/label_map.pb.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -112,6 +118,39 @@ ConvertImageSegmenterOptionsToProto(ImageSegmenterOptions* options) {
|
||||||
return options_proto;
|
return options_proto;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
absl::StatusOr<std::vector<std::string>> GetLabelsFromGraphConfig(
|
||||||
|
const CalculatorGraphConfig& graph_config) {
|
||||||
|
bool found_tensor_to_segmentation_calculator = false;
|
||||||
|
std::vector<std::string> labels;
|
||||||
|
for (const auto& node : graph_config.node()) {
|
||||||
|
if (node.calculator() ==
|
||||||
|
"mediapipe.tasks.TensorsToSegmentationCalculator") {
|
||||||
|
if (!found_tensor_to_segmentation_calculator) {
|
||||||
|
found_tensor_to_segmentation_calculator = true;
|
||||||
|
} else {
|
||||||
|
return absl::Status(CreateStatusWithPayload(
|
||||||
|
absl::StatusCode::kFailedPrecondition,
|
||||||
|
"The graph has more than one "
|
||||||
|
"mediapipe.tasks.TensorsToSegmentationCalculator."));
|
||||||
|
}
|
||||||
|
TensorsToSegmentationCalculatorOptions options =
|
||||||
|
node.options().GetExtension(
|
||||||
|
TensorsToSegmentationCalculatorOptions::ext);
|
||||||
|
if (!options.label_items().empty()) {
|
||||||
|
for (int i = 0; i < options.label_items_size(); ++i) {
|
||||||
|
if (!options.label_items().contains(i)) {
|
||||||
|
return absl::Status(CreateStatusWithPayload(
|
||||||
|
absl::StatusCode::kFailedPrecondition,
|
||||||
|
absl::StrFormat("The lablemap have no expected key: %d.", i)));
|
||||||
|
}
|
||||||
|
labels.push_back(options.label_items().at(i).name());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return labels;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
|
absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
|
||||||
|
@ -140,13 +179,22 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
|
||||||
kMicroSecondsPerMilliSecond);
|
kMicroSecondsPerMilliSecond);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
return core::VisionTaskApiFactory::Create<ImageSegmenter,
|
|
||||||
ImageSegmenterGraphOptionsProto>(
|
auto image_segmenter =
|
||||||
CreateGraphConfig(
|
core::VisionTaskApiFactory::Create<ImageSegmenter,
|
||||||
std::move(options_proto),
|
ImageSegmenterGraphOptionsProto>(
|
||||||
options->running_mode == core::RunningMode::LIVE_STREAM),
|
CreateGraphConfig(
|
||||||
std::move(options->base_options.op_resolver), options->running_mode,
|
std::move(options_proto),
|
||||||
std::move(packets_callback));
|
options->running_mode == core::RunningMode::LIVE_STREAM),
|
||||||
|
std::move(options->base_options.op_resolver), options->running_mode,
|
||||||
|
std::move(packets_callback));
|
||||||
|
if (!image_segmenter.ok()) {
|
||||||
|
return image_segmenter.status();
|
||||||
|
}
|
||||||
|
ASSIGN_OR_RETURN(
|
||||||
|
(*image_segmenter)->labels_,
|
||||||
|
GetLabelsFromGraphConfig((*image_segmenter)->runner_->GetGraphConfig()));
|
||||||
|
return image_segmenter;
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<std::vector<Image>> ImageSegmenter::Segment(
|
absl::StatusOr<std::vector<Image>> ImageSegmenter::Segment(
|
||||||
|
|
|
@ -189,6 +189,18 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
|
||||||
|
|
||||||
// Shuts down the ImageSegmenter when all works are done.
|
// Shuts down the ImageSegmenter when all works are done.
|
||||||
absl::Status Close() { return runner_->Close(); }
|
absl::Status Close() { return runner_->Close(); }
|
||||||
|
|
||||||
|
// Get the category label list of the ImageSegmenter can recognize. For
|
||||||
|
// CATEGORY_MASK type, the index in the category mask corresponds to the
|
||||||
|
// category in the label list. For CONFIDENCE_MASK type, the output mask list
|
||||||
|
// at index corresponds to the category in the label list.
|
||||||
|
//
|
||||||
|
// If there is no labelmap provided in the model file, empty label list is
|
||||||
|
// returned.
|
||||||
|
std::vector<std::string> GetLabels() { return labels_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<std::string> labels_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace image_segmenter
|
} // namespace image_segmenter
|
||||||
|
|
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||||
|
|
||||||
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h"
|
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h"
|
||||||
|
|
||||||
|
#include <array>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
|
@ -71,6 +72,13 @@ constexpr float kGoldenMaskSimilarity = 0.98;
|
||||||
// 20 means class index 2, etc.
|
// 20 means class index 2, etc.
|
||||||
constexpr int kGoldenMaskMagnificationFactor = 10;
|
constexpr int kGoldenMaskMagnificationFactor = 10;
|
||||||
|
|
||||||
|
constexpr std::array<absl::string_view, 21> kDeeplabLabelNames = {
|
||||||
|
"background", "aeroplane", "bicycle", "bird", "boat",
|
||||||
|
"bottle", "bus", "car", "cat", "chair",
|
||||||
|
"cow", "dining table", "dog", "horse", "motorbike",
|
||||||
|
"person", "potted plant", "sheep", "sofa", "train",
|
||||||
|
"tv"};
|
||||||
|
|
||||||
// Intentionally converting output into CV_8UC1 and then again into CV_32FC1
|
// Intentionally converting output into CV_8UC1 and then again into CV_32FC1
|
||||||
// as expected outputs are stored in CV_8UC1, so this conversion allows to do
|
// as expected outputs are stored in CV_8UC1, so this conversion allows to do
|
||||||
// fair comparison.
|
// fair comparison.
|
||||||
|
@ -244,6 +252,22 @@ TEST_F(CreateFromOptionsTest, FailsWithInputChannelOneModel) {
|
||||||
"channels = 3 or 4."));
|
"channels = 3 or 4."));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(GetLabelsTest, SucceedsWithLabelsInModel) {
|
||||||
|
auto options = std::make_unique<ImageSegmenterOptions>();
|
||||||
|
options->base_options.model_asset_path =
|
||||||
|
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
|
||||||
|
options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
|
||||||
|
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
||||||
|
ImageSegmenter::Create(std::move(options)));
|
||||||
|
const auto& labels = segmenter->GetLabels();
|
||||||
|
ASSERT_FALSE(labels.empty());
|
||||||
|
ASSERT_EQ(labels.size(), kDeeplabLabelNames.size());
|
||||||
|
for (int i = 0; i < labels.size(); ++i) {
|
||||||
|
EXPECT_EQ(labels[i], kDeeplabLabelNames[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
class ImageModeTest : public tflite_shims::testing::Test {};
|
class ImageModeTest : public tflite_shims::testing::Test {};
|
||||||
|
|
||||||
TEST_F(ImageModeTest, SucceedsWithCategoryMask) {
|
TEST_F(ImageModeTest, SucceedsWithCategoryMask) {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user