From 5ef40282d1342df6dad2fe7ef39def6ef4f353a2 Mon Sep 17 00:00:00 2001
From: Pierre Fenoll
Date: Thu, 10 Oct 2019 16:23:38 +0200
Subject: [PATCH] add a calculator that renders classifications

Signed-off-by: Pierre Fenoll
---
Reviewer notes (placed after the "---" separator, outside the commit message):
  * BUILD gains four targets for the new calculator: the options
    proto_library, its mediapipe_cc_proto_library, the calculator cc_library
    (alwayslink = 1) and a cc_test declared with size = "small".
  * The .cc, .proto and _test.cc files listed below are all new files.

 mediapipe/calculators/util/BUILD              |  63 ++++++
 ...assifications_to_render_data_calculator.cc | 179 ++++++++++++++++++
 ...ifications_to_render_data_calculator.proto |  45 +++++
 ...ications_to_render_data_calculator_test.cc |  70 +++++++
 4 files changed, 357 insertions(+)
 create mode 100644 mediapipe/calculators/util/classifications_to_render_data_calculator.cc
 create mode 100644 mediapipe/calculators/util/classifications_to_render_data_calculator.proto
 create mode 100644 mediapipe/calculators/util/classifications_to_render_data_calculator_test.cc

diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD
index 7bd06fe97..a6496a933 100644
--- a/mediapipe/calculators/util/BUILD
+++ b/mediapipe/calculators/util/BUILD
@@ -443,6 +443,17 @@ proto_library(
     ],
 )
 
+proto_library(
+    name = "classifications_to_render_data_calculator_proto",
+    srcs = ["classifications_to_render_data_calculator.proto"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//mediapipe/framework:calculator_proto",
+        "//mediapipe/util:color_proto",
+        "//mediapipe/util:render_data_proto",
+    ],
+)
+
 proto_library(
     name = "landmarks_to_render_data_calculator_proto",
     srcs = ["landmarks_to_render_data_calculator.proto"],
@@ -546,6 +557,36 @@ cc_library(
     alwayslink = 1,
 )
 
+mediapipe_cc_proto_library(
+    name = "classifications_to_render_data_calculator_cc_proto",
+    srcs = ["classifications_to_render_data_calculator.proto"],
+    cc_deps = [
+        "//mediapipe/framework:calculator_cc_proto",
+        "//mediapipe/util:color_cc_proto",
+        "//mediapipe/util:render_data_cc_proto",
+    ],
+    visibility = ["//visibility:public"],
+    deps = [":classifications_to_render_data_calculator_proto"],
+)
+
+cc_library(
+    name = "classifications_to_render_data_calculator",
+    srcs = ["classifications_to_render_data_calculator.cc"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":classifications_to_render_data_calculator_cc_proto",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework:calculator_options_cc_proto",
+        "//mediapipe/framework/formats:classification_cc_proto",
+        "//mediapipe/framework/port:ret_check",
+        "//mediapipe/util:color_cc_proto",
+        "//mediapipe/util:render_data_cc_proto",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
+    ],
+    alwayslink = 1,
+)
+
 mediapipe_cc_proto_library(
     name = "landmarks_to_render_data_calculator_cc_proto",
     srcs = ["landmarks_to_render_data_calculator.proto"],
@@ -615,6 +656,28 @@ cc_test(
     ],
 )
 
+cc_test(
+    name = "classifications_to_render_data_calculator_test",
+    size = "small",
+    srcs = ["classifications_to_render_data_calculator_test.cc"],
+    deps = [
+        ":classifications_to_render_data_calculator",
+        ":classifications_to_render_data_calculator_cc_proto",
+        "//mediapipe/framework:calculator_cc_proto",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework:calculator_runner",
+        "//mediapipe/framework:packet",
+        "//mediapipe/framework/deps:message_matchers",
+        "//mediapipe/framework/formats:classification_cc_proto",
+        "//mediapipe/framework/port:gtest_main",
+        "//mediapipe/framework/port:parse_text_proto",
+        "//mediapipe/framework/port:status",
+        "//mediapipe/util:color_cc_proto",
+        "//mediapipe/util:render_data_cc_proto",
+        "@com_google_absl//absl/memory",
+    ],
+)
+
 cc_library(
     name = "detection_letterbox_removal_calculator",
     srcs = ["detection_letterbox_removal_calculator.cc"],
diff --git a/mediapipe/calculators/util/classifications_to_render_data_calculator.cc b/mediapipe/calculators/util/classifications_to_render_data_calculator.cc
new file mode 100644
index 000000000..e1f6102bc
--- /dev/null
+++ b/mediapipe/calculators/util/classifications_to_render_data_calculator.cc
@@ -0,0 +1,179 @@
+// Copyright 2019 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <algorithm>
+#include <string>
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "mediapipe/calculators/util/classifications_to_render_data_calculator.pb.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/calculator_options.pb.h"
+#include "mediapipe/framework/formats/classification.pb.h"
+#include "mediapipe/framework/port/ret_check.h"
+#include "mediapipe/util/color.pb.h"
+#include "mediapipe/util/render_data.pb.h"
+namespace mediapipe {
+
+namespace {
+
+constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
+constexpr char kRenderDataTag[] = "RENDER_DATA";
+
+constexpr char kSceneLabelLabel[] = "LABEL";
+
+}  // namespace
+
+// A calculator that converts a ClassificationList proto to a RenderData proto
+// for visualization.
+//
+// Each Classification becomes one text annotation reading
+// "<score><text_delimiter><label>" (or "index=<index>" in place of an empty
+// label), stacked from the top left corner of the image.
+//
+// Example config:
+// node {
+//   calculator: "ClassificationsToRenderDataCalculator"
+//   input_stream: "CLASSIFICATIONS:classifications"
+//   output_stream: "RENDER_DATA:render_data"
+//   options {
+//     [mediapipe.ClassificationsToRenderDataCalculatorOptions.ext] {
+//       text_delimiter: " <- "
+//       thickness: 2.0
+//       color { r: 0 g: 0 b: 255 }
+//       text: { font_height: 2.0 }
+//     }
+//   }
+// }
+class ClassificationsToRenderDataCalculator : public CalculatorBase {
+ public:
+  ClassificationsToRenderDataCalculator() = default;
+  ~ClassificationsToRenderDataCalculator() override = default;
+  ClassificationsToRenderDataCalculator(
+      const ClassificationsToRenderDataCalculator&) = delete;
+  ClassificationsToRenderDataCalculator& operator=(
+      const ClassificationsToRenderDataCalculator&) = delete;
+
+  static ::mediapipe::Status GetContract(CalculatorContract* cc);
+  ::mediapipe::Status Open(CalculatorContext* cc) override;
+  ::mediapipe::Status Process(CalculatorContext* cc) override;
+
+ private:
+  // Copies the color and thickness from `options` onto `render_annotation`.
+  static void SetRenderAnnotationColorThickness(
+      const ClassificationsToRenderDataCalculatorOptions& options,
+      RenderAnnotation* render_annotation);
+
+  // Sets the anchor of `text`. When `normalized`, `left` is raised to at least
+  // 0.0 and `baseline` is capped at 1.0.
+  static void SetTextCoordinate(bool normalized, double left, double baseline,
+                                RenderAnnotation::Text* text);
+
+  // Appends the text annotation for the `ith` classification (0-based) to
+  // `render_data`, on its own line of height `text_line_height`.
+  static void AddLabel(
+      int ith, const Classification& classification,
+      const ClassificationsToRenderDataCalculatorOptions& options,
+      float text_line_height, RenderData* render_data);
+};
+REGISTER_CALCULATOR(ClassificationsToRenderDataCalculator);
+
+// Declares the ClassificationList input and the RenderData output streams.
+::mediapipe::Status ClassificationsToRenderDataCalculator::GetContract(
+    CalculatorContract* cc) {
+  RET_CHECK(cc->Inputs().HasTag(kClassificationsTag));
+  RET_CHECK(cc->Outputs().HasTag(kRenderDataTag));
+  cc->Inputs().Tag(kClassificationsTag).Set<ClassificationList>();
+  cc->Outputs().Tag(kRenderDataTag).Set<RenderData>();
+  return ::mediapipe::OkStatus();
+}
+
+// Output packets carry the timestamp of the input packet they came from.
+::mediapipe::Status ClassificationsToRenderDataCalculator::Open(
+    CalculatorContext* cc) {
+  cc->SetOffset(TimestampDiff(0));
+  return ::mediapipe::OkStatus();
+}
+
+// Emits one RenderData holding a text annotation for every classification.
+::mediapipe::Status ClassificationsToRenderDataCalculator::Process(
+    CalculatorContext* cc) {
+  const auto& classification_list =
+      cc->Inputs().Tag(kClassificationsTag).Get<ClassificationList>();
+  const auto& classifications = classification_list.classification();
+  // Nothing to render; this also keeps the division below away from zero.
+  if (classifications.empty()) return ::mediapipe::OkStatus();
+
+  const auto& options =
+      cc->Options<ClassificationsToRenderDataCalculatorOptions>();
+  auto render_data = absl::make_unique<RenderData>();
+  render_data->set_scene_class(options.scene_class());
+
+  // Each label gets an equal share of a tenth of the configured font height.
+  const float text_line_height = static_cast<float>(
+      options.text().font_height() / classifications.size() / 10);
+
+  for (int ith = 0; ith < classifications.size(); ++ith) {
+    AddLabel(ith, classifications.Get(ith), options, text_line_height,
+             render_data.get());
+  }
+
+  cc->Outputs()
+      .Tag(kRenderDataTag)
+      .Add(render_data.release(), cc->InputTimestamp());
+  return ::mediapipe::OkStatus();
+}
+
+void ClassificationsToRenderDataCalculator::SetRenderAnnotationColorThickness(
+    const ClassificationsToRenderDataCalculatorOptions& options,
+    RenderAnnotation* render_annotation) {
+  render_annotation->mutable_color()->set_r(options.color().r());
+  render_annotation->mutable_color()->set_g(options.color().g());
+  render_annotation->mutable_color()->set_b(options.color().b());
+  render_annotation->set_thickness(options.thickness());
+}
+
+void ClassificationsToRenderDataCalculator::SetTextCoordinate(
+    bool normalized, double left, double baseline,
+    RenderAnnotation::Text* text) {
+  text->set_normalized(normalized);
+  // Normalized coordinates must be between 0.0 and 1.0, if they are used.
+  text->set_left(normalized ? std::max(left, 0.0) : left);
+  text->set_baseline(normalized ? std::min(baseline, 1.0) : baseline);
+}
+
+void ClassificationsToRenderDataCalculator::AddLabel(
+    int ith, const Classification& classification,
+    const ClassificationsToRenderDataCalculatorOptions& options,
+    float text_line_height, RenderData* render_data) {
+  const std::string label =
+      classification.label().empty()
+          ? absl::StrCat("index=", classification.index())
+          : classification.label();
+  const std::string score_and_label =
+      absl::StrCat(classification.score(), options.text_delimiter(), label);
+  auto* label_annotation = render_data->add_render_annotations();
+  label_annotation->set_scene_tag(kSceneLabelLabel);
+  SetRenderAnnotationColorThickness(options, label_annotation);
+  // Start from the configured text style, then override the per-label fields.
+  auto* text = label_annotation->mutable_text();
+  *text = options.text();
+  text->set_display_text(score_and_label);
+  text->set_font_height(text_line_height);
+  // Line `ith` has its baseline (ith + 1) line heights below the top edge.
+  SetTextCoordinate(true, 0.0, (ith + 1) * text_line_height, text);
+}
+
+}  // namespace mediapipe
diff --git a/mediapipe/calculators/util/classifications_to_render_data_calculator.proto b/mediapipe/calculators/util/classifications_to_render_data_calculator.proto
new file mode 100644
index 000000000..091820c61
--- /dev/null
+++ b/mediapipe/calculators/util/classifications_to_render_data_calculator.proto
@@ -0,0 +1,45 @@
+// Copyright 2019 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto2";
+
+package mediapipe;
+
+import "mediapipe/framework/calculator.proto";
+import "mediapipe/util/color.proto";
+import "mediapipe/util/render_data.proto";
+
+message ClassificationsToRenderDataCalculatorOptions {
+  extend CalculatorOptions {
+    optional ClassificationsToRenderDataCalculatorOptions ext = 299999999;
+  }
+
+  // Delimiter placed between the score and the label (or "index=<index>").
+  optional string text_delimiter = 2 [default = " "];
+
+  // Rendering options for the label text. display_text, the coordinates and
+  // font_height (rescaled per label) are overwritten by the calculator.
+  optional RenderAnnotation.Text text = 4;
+
+  // Thickness for drawing the score(s) and label(s).
+  optional double thickness = 5 [default = 1.0];
+
+  // Color for drawing the score(s) and label(s).
+  optional Color color = 6;
+
+  // An optional string that identifies this class of annotations for the
+  // render data output this calculator produces. It should be unique among
+  // all instances of this calculator present in the same graph.
+  optional string scene_class = 7 [default = "CLASSIFICATION"];
+}
diff --git a/mediapipe/calculators/util/classifications_to_render_data_calculator_test.cc b/mediapipe/calculators/util/classifications_to_render_data_calculator_test.cc
new file mode 100644
index 000000000..808ef25a8
--- /dev/null
+++ b/mediapipe/calculators/util/classifications_to_render_data_calculator_test.cc
@@ -0,0 +1,70 @@
+// Copyright 2019 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "absl/memory/memory.h"
+#include "mediapipe/calculators/util/classifications_to_render_data_calculator.pb.h"
+#include "mediapipe/framework/calculator.pb.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/calculator_runner.h"
+#include "mediapipe/framework/deps/message_matchers.h"
+#include "mediapipe/framework/formats/classification.pb.h"
+#include "mediapipe/framework/packet.h"
+#include "mediapipe/framework/port/gmock.h"
+#include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/framework/port/parse_text_proto.h"
+#include "mediapipe/framework/port/status_matchers.h"
+#include "mediapipe/util/color.pb.h"
+#include "mediapipe/util/render_data.pb.h"
+
+namespace mediapipe {
+
+// Builds a Classification with the given index, score and label.
+Classification CreateClassification(int32 index, float score,
+                                    const std::string& label) {
+  Classification classification;
+  classification.set_score(score);
+  classification.set_index(index);
+  classification.set_label(label);
+  return classification;
+}
+
+// One labelled and one unlabelled classification yield two text annotations.
+TEST(ClassificationsToRenderDataCalculatorTest, OnlyClassificationList) {
+  CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
+    calculator: "ClassificationsToRenderDataCalculator"
+    input_stream: "CLASSIFICATIONS:classifications"
+    output_stream: "RENDER_DATA:render_data"
+  )"));
+
+  auto classifications = absl::make_unique<ClassificationList>();
+  *classifications->add_classification() =
+      CreateClassification(0, 0.9, "zeroth_label");
+  *classifications->add_classification() = CreateClassification(1, 0.3, "");
+
+  runner.MutableInputs()->Tag("CLASSIFICATIONS").packets.push_back(
+      Adopt(classifications.release()).At(Timestamp::PostStream()));
+
+  MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
+  const auto& output = runner.Outputs().Tag("RENDER_DATA").packets;
+  ASSERT_EQ(1, output.size());
+  const auto& actual = output[0].Get<RenderData>();
+  EXPECT_EQ(actual.render_annotations_size(), 2);
+  // With default options the score and label are joined by a single space.
+  EXPECT_EQ(actual.render_annotations(0).text().display_text(),
+            "0.9 zeroth_label");
+  // An empty label is replaced by "index=<index>".
+  EXPECT_EQ(actual.render_annotations(1).text().display_text(), "0.3 index=1");
+}
+
+}  // namespace mediapipe