diff --git a/Dockerfile b/Dockerfile index eb983af53..462dacbd4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -51,8 +51,7 @@ RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-8 100 --slave /u RUN pip3 install --upgrade setuptools RUN pip3 install wheel RUN pip3 install future -RUN pip3 install absl-py -RUN pip3 install numpy +RUN pip3 install absl-py numpy opencv-contrib-python protobuf==3.20.1 RUN pip3 install six==1.14.0 RUN pip3 install tensorflow==2.2.0 RUN pip3 install tf_slim diff --git a/mediapipe/examples/ios/common/Assets.xcassets/AppIcon.appiconset/76_c_Ipad.png b/mediapipe/examples/ios/common/Assets.xcassets/AppIcon.appiconset/76_c_Ipad.png index d28effc39..8e4073050 100644 Binary files a/mediapipe/examples/ios/common/Assets.xcassets/AppIcon.appiconset/76_c_Ipad.png and b/mediapipe/examples/ios/common/Assets.xcassets/AppIcon.appiconset/76_c_Ipad.png differ diff --git a/mediapipe/examples/ios/common/Assets.xcassets/AppIcon.appiconset/76_c_Ipad_2x.png b/mediapipe/examples/ios/common/Assets.xcassets/AppIcon.appiconset/76_c_Ipad_2x.png new file mode 100644 index 000000000..d28effc39 Binary files /dev/null and b/mediapipe/examples/ios/common/Assets.xcassets/AppIcon.appiconset/76_c_Ipad_2x.png differ diff --git a/mediapipe/examples/ios/common/Assets.xcassets/AppIcon.appiconset/Contents.json b/mediapipe/examples/ios/common/Assets.xcassets/AppIcon.appiconset/Contents.json index 8ae934c76..3ed9f5238 100644 --- a/mediapipe/examples/ios/common/Assets.xcassets/AppIcon.appiconset/Contents.json +++ b/mediapipe/examples/ios/common/Assets.xcassets/AppIcon.appiconset/Contents.json @@ -84,6 +84,7 @@ { "idiom" : "ipad", "size" : "76x76", + "filename" : "76_c_Ipad_2x.png", "scale" : "2x" }, { diff --git a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.proto b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.proto index 1ecf8e072..9dd65a265 100644 --- 
a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.proto +++ b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.proto @@ -25,7 +25,7 @@ message AudioClassifierOptions { extend mediapipe.CalculatorOptions { optional AudioClassifierOptions ext = 451755788; } - // Base options for configuring Task library, such as specifying the TfLite + // Base options for configuring MediaPipe Tasks, such as specifying the TfLite // model file with metadata, accelerator options, etc. optional core.proto.BaseOptions base_options = 1; diff --git a/mediapipe/tasks/cc/components/calculators/BUILD b/mediapipe/tasks/cc/components/calculators/BUILD index c8985c98b..8b553dea4 100644 --- a/mediapipe/tasks/cc/components/calculators/BUILD +++ b/mediapipe/tasks/cc/components/calculators/BUILD @@ -43,3 +43,73 @@ cc_library( ], alwayslink = 1, ) + +mediapipe_proto_library( + name = "score_calibration_calculator_proto", + srcs = ["score_calibration_calculator.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_library( + name = "score_calibration_calculator", + srcs = ["score_calibration_calculator.cc"], + deps = [ + ":score_calibration_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:ret_check", + "//mediapipe/tasks/cc:common", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + ], + alwayslink = 1, +) + +cc_test( + name = "score_calibration_calculator_test", + srcs = ["score_calibration_calculator_test.cc"], + deps = [ + ":score_calibration_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:gtest_main", + 
"//mediapipe/framework/port:parse_text_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_library( + name = "score_calibration_utils", + srcs = ["score_calibration_utils.cc"], + hdrs = ["score_calibration_utils.h"], + deps = [ + ":score_calibration_calculator_cc_proto", + "//mediapipe/framework/port:status", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/metadata:metadata_schema_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_test( + name = "score_calibration_utils_test", + srcs = ["score_calibration_utils_test.cc"], + deps = [ + ":score_calibration_calculator_cc_proto", + ":score_calibration_utils", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/tasks/metadata:metadata_schema_cc", + "@com_google_absl//absl/strings", + ], +) diff --git a/mediapipe/tasks/cc/components/calculators/score_calibration_calculator.cc b/mediapipe/tasks/cc/components/calculators/score_calibration_calculator.cc new file mode 100644 index 000000000..c689cc255 --- /dev/null +++ b/mediapipe/tasks/cc/components/calculators/score_calibration_calculator.cc @@ -0,0 +1,259 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
namespace {
// Threshold passed to ClampedLog() by the LOG and INVERSE_LOGISTIC score
// transformations; below it the mirrored branch is used so that log(<= 0.0) is
// never evaluated.
constexpr float kLogScoreMinimum = 1e-16;

// Returns the following, depending on x:
//   x >= threshold: log(x)
//   x <  threshold: 2 * log(threshold) - log(2 * threshold - x)
// This form (a) is anti-symmetric about the threshold and (b) has continuous
// value and first derivative. This is done to prevent taking the log of values
// close to 0 which can lead to floating point errors and is better than simple
// clamping since it preserves order for scores less than the threshold.
//
// The computation is carried out in double precision and narrowed to float on
// return. `threshold` must be strictly positive for the result to be finite.
float ClampedLog(float x, float threshold) {
  if (x < threshold) {
    // Mirror image of log() around the threshold; 2 * threshold - x is always
    // greater than threshold here, so the argument of log() stays positive.
    return static_cast<float>(
        2.0 * std::log(static_cast<double>(threshold)) -
        std::log(2.0 * threshold - x));
  }
  return static_cast<float>(std::log(static_cast<double>(x)));
}
}  // namespace
if INDICES is not connected), +// x[i] will be calibrated using the sigmoid provided at index i in the +// options. +// INDICES - std::vector @Optional +// An optional vector containing a single Tensor `y` of type kFloat32 and +// same size as `x`. If provided, x[i] will be calibrated using the sigmoid +// provided at index y[i] (casted as an integer) in the options. `x` and `y` +// must contain the same number of elements. Typically used for object +// detection models. +// +// Outputs: +// CALIBRATED_SCORES - std::vector +// A vector containing a single Tensor of type kFloat32 and of the same size +// as the input tensors. Contains the output calibrated scores. +class ScoreCalibrationCalculator : public Node { + public: + static constexpr Input> kScoresIn{"SCORES"}; + static constexpr Input>::Optional kIndicesIn{"INDICES"}; + static constexpr Output> kScoresOut{"CALIBRATED_SCORES"}; + MEDIAPIPE_NODE_CONTRACT(kScoresIn, kIndicesIn, kScoresOut); + + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + + private: + ScoreCalibrationCalculatorOptions options_; + std::function score_transformation_; + + // Computes the calibrated score for the provided index. Does not check for + // out-of-bounds index. + float ComputeCalibratedScore(int index, float score); + // Same as above, but does check for out-of-bounds index. + absl::StatusOr SafeComputeCalibratedScore(int index, float score); +}; + +absl::Status ScoreCalibrationCalculator::Open(CalculatorContext* cc) { + options_ = cc->Options(); + // Sanity checks. 
+ if (options_.sigmoids_size() == 0) { + return CreateStatusWithPayload(StatusCode::kInvalidArgument, + "Expected at least one sigmoid, found none.", + MediaPipeTasksStatus::kInvalidArgumentError); + } + for (const auto& sigmoid : options_.sigmoids()) { + if (sigmoid.has_scale() && sigmoid.scale() < 0.0) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("The scale parameter of the sigmoids must be " + "positive, found %f.", + sigmoid.scale()), + MediaPipeTasksStatus::kInvalidArgumentError); + } + } + // Set score transformation function once and for all. + switch (options_.score_transformation()) { + case tasks::ScoreCalibrationCalculatorOptions::IDENTITY: + score_transformation_ = [](float x) { return x; }; + break; + case tasks::ScoreCalibrationCalculatorOptions::LOG: + score_transformation_ = [](float x) { + return ClampedLog(x, kLogScoreMinimum); + }; + break; + case tasks::ScoreCalibrationCalculatorOptions::INVERSE_LOGISTIC: + score_transformation_ = [](float x) { + return (ClampedLog(x, kLogScoreMinimum) - + ClampedLog(1.0 - x, kLogScoreMinimum)); + }; + break; + default: + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat( + "Unsupported ScoreTransformation type: %s", + ScoreCalibrationCalculatorOptions::ScoreTransformation_Name( + options_.score_transformation())), + MediaPipeTasksStatus::kInvalidArgumentError); + } + return absl::OkStatus(); +} + +absl::Status ScoreCalibrationCalculator::Process(CalculatorContext* cc) { + RET_CHECK_EQ(kScoresIn(cc)->size(), 1); + const auto& scores = (*kScoresIn(cc))[0]; + RET_CHECK(scores.element_type() == Tensor::ElementType::kFloat32); + auto scores_view = scores.GetCpuReadView(); + const float* raw_scores = scores_view.buffer(); + int num_scores = scores.shape().num_elements(); + + auto output_tensors = std::make_unique>(); + output_tensors->reserve(1); + output_tensors->emplace_back(scores.element_type(), scores.shape()); + auto calibrated_scores = 
&output_tensors->back(); + auto calibrated_scores_view = calibrated_scores->GetCpuWriteView(); + float* raw_calibrated_scores = calibrated_scores_view.buffer(); + + if (kIndicesIn(cc).IsConnected()) { + RET_CHECK_EQ(kIndicesIn(cc)->size(), 1); + const auto& indices = (*kIndicesIn(cc))[0]; + RET_CHECK(indices.element_type() == Tensor::ElementType::kFloat32); + if (num_scores != indices.shape().num_elements()) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("Mismatch between number of elements in the input " + "scores tensor (%d) and indices tensor (%d).", + num_scores, indices.shape().num_elements()), + MediaPipeTasksStatus::kMetadataInconsistencyError); + } + auto indices_view = indices.GetCpuReadView(); + const float* raw_indices = indices_view.buffer(); + for (int i = 0; i < num_scores; ++i) { + // Use the "safe" flavor as we need to check that the externally provided + // indices are not out-of-bounds. + ASSIGN_OR_RETURN(raw_calibrated_scores[i], + SafeComputeCalibratedScore( + static_cast(raw_indices[i]), raw_scores[i])); + } + } else { + if (num_scores != options_.sigmoids_size()) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("Mismatch between number of sigmoids (%d) and number " + "of elements in the input scores tensor (%d).", + options_.sigmoids_size(), num_scores), + MediaPipeTasksStatus::kMetadataInconsistencyError); + } + for (int i = 0; i < num_scores; ++i) { + // Use the "unsafe" flavor as we have already checked for out-of-bounds + // issues. 
+ raw_calibrated_scores[i] = ComputeCalibratedScore(i, raw_scores[i]); + } + } + kScoresOut(cc).Send(std::move(output_tensors)); + return absl::OkStatus(); +} + +float ScoreCalibrationCalculator::ComputeCalibratedScore(int index, + float score) { + const auto& sigmoid = options_.sigmoids(index); + + bool is_empty = + !sigmoid.has_scale() || !sigmoid.has_offset() || !sigmoid.has_slope(); + bool is_below_min_score = + sigmoid.has_min_score() && score < sigmoid.min_score(); + if (is_empty || is_below_min_score) { + return options_.default_score(); + } + + float transformed_score = score_transformation_(score); + float scale_shifted_score = + transformed_score * sigmoid.slope() + sigmoid.offset(); + // For numerical stability use 1 / (1+exp(-x)) when scale_shifted_score >= 0 + // and exp(x) / (1+exp(x)) when scale_shifted_score < 0. + float calibrated_score; + if (scale_shifted_score >= 0.0) { + calibrated_score = + sigmoid.scale() / + (1.0 + std::exp(static_cast(-scale_shifted_score))); + } else { + float score_exp = std::exp(static_cast(scale_shifted_score)); + calibrated_score = sigmoid.scale() * score_exp / (1.0 + score_exp); + } + // Scale is non-negative (checked in SigmoidFromLabelAndLine), + // thus calibrated_score should be in the range of [0, scale]. However, due to + // numberical stability issue, it may fall out of the boundary. Cap the value + // to [0, scale] instead. 
+ return std::max(std::min(calibrated_score, sigmoid.scale()), 0.0f); +} + +absl::StatusOr ScoreCalibrationCalculator::SafeComputeCalibratedScore( + int index, float score) { + if (index < 0) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("Expected positive indices, found %d.", index), + MediaPipeTasksStatus::kInvalidArgumentError); + } + if (index > options_.sigmoids_size()) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("Unable to get score calibration parameters for index " + "%d : only %d sigmoids were provided.", + index, options_.sigmoids_size()), + MediaPipeTasksStatus::kMetadataInconsistencyError); + } + return ComputeCalibratedScore(index, score); +} + +MEDIAPIPE_REGISTER_NODE(ScoreCalibrationCalculator); + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/calculators/score_calibration_calculator.proto b/mediapipe/tasks/cc/components/calculators/score_calibration_calculator.proto new file mode 100644 index 000000000..11d944c93 --- /dev/null +++ b/mediapipe/tasks/cc/components/calculators/score_calibration_calculator.proto @@ -0,0 +1,67 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
message ScoreCalibrationCalculatorOptions {
  extend mediapipe.CalculatorOptions {
    optional ScoreCalibrationCalculatorOptions ext = 470204318;
  }

  // Score calibration parameters for one individual category. The formula used
  // to transform the uncalibrated score `x` is:
  // * `f(x) = scale / (1 + e^-(slope * g(x) + offset))` if `x >= min_score` or
  //   if no min_score has been specified,
  // * `f(x) = default_score` otherwise or if no scale, slope or offset have
  //   been specified.
  //
  // Where:
  // * scale must be non-negative (negative values are rejected by the
  //   calculator),
  // * g(x) is a global (i.e. category independent) transform specified using
  //   the score_transformation field,
  // * default_score is a global parameter defined below.
  //
  // There should be exactly one sigmoid per supported output category in the
  // model, with either:
  // * no fields set,
  // * scale, slope and offset set,
  // * all fields set.
  message Sigmoid {
    optional float scale = 1;
    optional float slope = 2;
    optional float offset = 3;
    optional float min_score = 4;
  }
  repeated Sigmoid sigmoids = 1;

  // Score transformation that defines the `g(x)` function in the above formula.
  enum ScoreTransformation {
    // Rejected by the calculator as unsupported.
    UNSPECIFIED = 0;
    // g(x) = x.
    IDENTITY = 1;
    // g(x) = log(x).
    LOG = 2;
    // g(x) = log(x) - log(1 - x).
    INVERSE_LOGISTIC = 3;
  }
  optional ScoreTransformation score_transformation = 2 [default = IDENTITY];

  // Default score, returned for categories whose sigmoid is empty or whose
  // uncalibrated score is below min_score.
  optional float default_score = 3;
}
+ optional float default_score = 3; +} diff --git a/mediapipe/tasks/cc/components/calculators/score_calibration_calculator_test.cc b/mediapipe/tasks/cc/components/calculators/score_calibration_calculator_test.cc new file mode 100644 index 000000000..8134d86d2 --- /dev/null +++ b/mediapipe/tasks/cc/components/calculators/score_calibration_calculator_test.cc @@ -0,0 +1,309 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { +namespace { + +using ::mediapipe::ParseTextProtoOrDie; +using ::testing::HasSubstr; +using ::testing::TestParamInfo; +using ::testing::TestWithParam; +using ::testing::Values; +using Node = ::mediapipe::CalculatorGraphConfig::Node; + +// Builds the graph and feeds inputs. 
+void BuildGraph(CalculatorRunner* runner, std::vector scores, + std::optional> indices = std::nullopt) { + auto scores_tensors = std::make_unique>(); + scores_tensors->emplace_back( + Tensor::ElementType::kFloat32, + Tensor::Shape{1, static_cast(scores.size())}); + auto scores_view = scores_tensors->back().GetCpuWriteView(); + float* scores_buffer = scores_view.buffer(); + ASSERT_NE(scores_buffer, nullptr); + for (int i = 0; i < scores.size(); ++i) { + scores_buffer[i] = scores[i]; + } + auto& input_scores_packets = runner->MutableInputs()->Tag("SCORES").packets; + input_scores_packets.push_back( + mediapipe::Adopt(scores_tensors.release()).At(mediapipe::Timestamp(0))); + + if (indices.has_value()) { + auto indices_tensors = std::make_unique>(); + indices_tensors->emplace_back( + Tensor::ElementType::kFloat32, + Tensor::Shape{1, static_cast(indices->size())}); + auto indices_view = indices_tensors->back().GetCpuWriteView(); + float* indices_buffer = indices_view.buffer(); + ASSERT_NE(indices_buffer, nullptr); + for (int i = 0; i < indices->size(); ++i) { + indices_buffer[i] = static_cast((*indices)[i]); + } + auto& input_indices_packets = + runner->MutableInputs()->Tag("INDICES").packets; + input_indices_packets.push_back(mediapipe::Adopt(indices_tensors.release()) + .At(mediapipe::Timestamp(0))); + } +} + +// Compares the provided tensor contents with the expected values. 
+void ValidateResult(const Tensor& actual, const std::vector& expected) { + EXPECT_EQ(actual.element_type(), Tensor::ElementType::kFloat32); + EXPECT_EQ(expected.size(), actual.shape().num_elements()); + auto view = actual.GetCpuReadView(); + auto buffer = view.buffer(); + for (int i = 0; i < expected.size(); ++i) { + EXPECT_FLOAT_EQ(expected[i], buffer[i]); + } +} + +TEST(ScoreCalibrationCalculatorTest, FailsWithNoSigmoid) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "ScoreCalibrationCalculator" + input_stream: "SCORES:scores" + output_stream: "CALIBRATED_SCORES:calibrated_scores" + options { + [mediapipe.tasks.ScoreCalibrationCalculatorOptions.ext] {} + } + )pb")); + + BuildGraph(&runner, {0.5, 0.5, 0.5}); + auto status = runner.Run(); + + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(status.message(), + HasSubstr("Expected at least one sigmoid, found none")); +} + +TEST(ScoreCalibrationCalculatorTest, FailsWithNegativeScale) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "ScoreCalibrationCalculator" + input_stream: "SCORES:scores" + output_stream: "CALIBRATED_SCORES:calibrated_scores" + options { + [mediapipe.tasks.ScoreCalibrationCalculatorOptions.ext] { + sigmoids { slope: 1 offset: 1 scale: -1 } + } + } + )pb")); + + BuildGraph(&runner, {0.5, 0.5, 0.5}); + auto status = runner.Run(); + + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT( + status.message(), + HasSubstr("The scale parameter of the sigmoids must be positive")); +} + +TEST(ScoreCalibrationCalculatorTest, FailsWithUnspecifiedTransformation) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "ScoreCalibrationCalculator" + input_stream: "SCORES:scores" + output_stream: "CALIBRATED_SCORES:calibrated_scores" + options { + [mediapipe.tasks.ScoreCalibrationCalculatorOptions.ext] { + sigmoids { slope: 1 offset: 1 scale: 1 } + score_transformation: UNSPECIFIED + } + } + )pb")); + + 
BuildGraph(&runner, {0.5, 0.5, 0.5}); + auto status = runner.Run(); + + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(status.message(), + HasSubstr("Unsupported ScoreTransformation type")); +} + +// Struct holding the parameters for parameterized tests below. +struct CalibrationTestParams { + // The score transformation to apply. + std::string score_transformation; + // The expected results. + std::vector expected_results; +}; + +class CalibrationWithoutIndicesTest + : public TestWithParam {}; + +TEST_P(CalibrationWithoutIndicesTest, Succeeds) { + CalculatorRunner runner(ParseTextProtoOrDie(absl::StrFormat( + R"pb( + calculator: "ScoreCalibrationCalculator" + input_stream: "SCORES:scores" + output_stream: "CALIBRATED_SCORES:calibrated_scores" + options { + [mediapipe.tasks.ScoreCalibrationCalculatorOptions.ext] { + sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 } + sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 min_score: 0.3 } + sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 min_score: 0.6 } + sigmoids {} + score_transformation: %s + default_score: 0.2 + } + } + )pb", + GetParam().score_transformation))); + + BuildGraph(&runner, {0.2, 0.3, 0.4, 0.5}); + MP_ASSERT_OK(runner.Run()); + const Tensor& results = runner.Outputs() + .Get("CALIBRATED_SCORES", 0) + .packets[0] + .Get>()[0]; + + ValidateResult(results, GetParam().expected_results); +} + +INSTANTIATE_TEST_SUITE_P( + ScoreCalibrationCalculatorTest, CalibrationWithoutIndicesTest, + Values(CalibrationTestParams{.score_transformation = "IDENTITY", + .expected_results = {0.4948505976, + 0.5059588508, 0.2, 0.2}}, + CalibrationTestParams{ + .score_transformation = "LOG", + .expected_results = {0.2976901255, 0.3393665735, 0.2, 0.2}}, + CalibrationTestParams{ + .score_transformation = "INVERSE_LOGISTIC", + .expected_results = {0.3203217641, 0.3778080605, 0.2, 0.2}}), + [](const TestParamInfo& info) { + return info.param.score_transformation; + }); + +TEST(ScoreCalibrationCalculatorTest, 
FailsWithMissingSigmoids) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "ScoreCalibrationCalculator" + input_stream: "SCORES:scores" + output_stream: "CALIBRATED_SCORES:calibrated_scores" + options { + [mediapipe.tasks.ScoreCalibrationCalculatorOptions.ext] { + sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 } + sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 min_score: 0.3 } + sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 min_score: 0.6 } + sigmoids {} + score_transformation: LOG + default_score: 0.2 + } + } + )pb")); + + BuildGraph(&runner, {0.2, 0.3, 0.4, 0.5, 0.6}); + auto status = runner.Run(); + + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(status.message(), + HasSubstr("Mismatch between number of sigmoids")); +} + +TEST(ScoreCalibrationCalculatorTest, SucceedsWithIndices) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "ScoreCalibrationCalculator" + input_stream: "SCORES:scores" + input_stream: "INDICES:indices" + output_stream: "CALIBRATED_SCORES:calibrated_scores" + options { + [mediapipe.tasks.ScoreCalibrationCalculatorOptions.ext] { + sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 } + sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 min_score: 0.3 } + sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 min_score: 0.6 } + sigmoids {} + score_transformation: IDENTITY + default_score: 0.2 + } + } + )pb")); + std::vector indices = {1, 2, 3, 0}; + + BuildGraph(&runner, {0.3, 0.4, 0.5, 0.2}, indices); + MP_ASSERT_OK(runner.Run()); + const Tensor& results = runner.Outputs() + .Get("CALIBRATED_SCORES", 0) + .packets[0] + .Get>()[0]; + ValidateResult(results, {0.5059588508, 0.2, 0.2, 0.4948505976}); +} + +TEST(ScoreCalibrationCalculatorTest, FailsWithNegativeIndex) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "ScoreCalibrationCalculator" + input_stream: "SCORES:scores" + input_stream: "INDICES:indices" + output_stream: "CALIBRATED_SCORES:calibrated_scores" + options { + 
[mediapipe.tasks.ScoreCalibrationCalculatorOptions.ext] { + sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 } + sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 min_score: 0.3 } + sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 min_score: 0.6 } + sigmoids {} + score_transformation: IDENTITY + default_score: 0.2 + } + } + )pb")); + std::vector indices = {0, 1, 2, -1}; + + BuildGraph(&runner, {0.2, 0.3, 0.4, 0.5}, indices); + auto status = runner.Run(); + + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(status.message(), HasSubstr("Expected positive indices")); +} + +TEST(ScoreCalibrationCalculatorTest, FailsWithOutOfBoundsIndex) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "ScoreCalibrationCalculator" + input_stream: "SCORES:scores" + input_stream: "INDICES:indices" + output_stream: "CALIBRATED_SCORES:calibrated_scores" + options { + [mediapipe.tasks.ScoreCalibrationCalculatorOptions.ext] { + sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 } + sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 min_score: 0.3 } + sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 min_score: 0.6 } + sigmoids {} + score_transformation: IDENTITY + default_score: 0.2 + } + } + )pb")); + std::vector indices = {0, 1, 5, 3}; + + BuildGraph(&runner, {0.2, 0.3, 0.4, 0.5}, indices); + auto status = runner.Run(); + + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT( + status.message(), + HasSubstr("Unable to get score calibration parameters for index")); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/calculators/score_calibration_utils.cc b/mediapipe/tasks/cc/components/calculators/score_calibration_utils.cc new file mode 100644 index 000000000..120344be6 --- /dev/null +++ b/mediapipe/tasks/cc/components/calculators/score_calibration_utils.cc @@ -0,0 +1,115 @@ +// Copyright 2022 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h" + +#include + +#include "absl/status/status.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" +#include "mediapipe/tasks/metadata/metadata_schema_generated.h" + +namespace mediapipe { +namespace tasks { + +namespace { +// Converts ScoreTransformation type from TFLite Metadata to calculator options. +ScoreCalibrationCalculatorOptions::ScoreTransformation +ConvertScoreTransformationType(tflite::ScoreTransformationType type) { + switch (type) { + case tflite::ScoreTransformationType_IDENTITY: + return ScoreCalibrationCalculatorOptions::IDENTITY; + case tflite::ScoreTransformationType_LOG: + return ScoreCalibrationCalculatorOptions::LOG; + case tflite::ScoreTransformationType_INVERSE_LOGISTIC: + return ScoreCalibrationCalculatorOptions::INVERSE_LOGISTIC; + } +} + +// Parses a single line of the score calibration file into the provided sigmoid. 
+absl::Status FillSigmoidFromLine( + absl::string_view line, + ScoreCalibrationCalculatorOptions::Sigmoid* sigmoid) { + if (line.empty()) { + return absl::OkStatus(); + } + std::vector str_params = absl::StrSplit(line, ','); + if (str_params.size() != 3 && str_params.size() != 4) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("Expected 3 or 4 parameters per line in score " + "calibration file, got %d.", + str_params.size()), + MediaPipeTasksStatus::kMetadataMalformedScoreCalibrationError); + } + std::vector params(str_params.size()); + for (int i = 0; i < str_params.size(); ++i) { + if (!absl::SimpleAtof(str_params[i], ¶ms[i])) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat( + "Could not parse score calibration parameter as float: %s.", + str_params[i]), + MediaPipeTasksStatus::kMetadataMalformedScoreCalibrationError); + } + } + if (params[0] < 0) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat( + "The scale parameter of the sigmoids must be positive, found %f.", + params[0]), + MediaPipeTasksStatus::kMetadataMalformedScoreCalibrationError); + } + sigmoid->set_scale(params[0]); + sigmoid->set_slope(params[1]); + sigmoid->set_offset(params[2]); + if (params.size() == 4) { + sigmoid->set_min_score(params[3]); + } + return absl::OkStatus(); +} +} // namespace + +absl::Status ConfigureScoreCalibration( + tflite::ScoreTransformationType score_transformation, float default_score, + absl::string_view score_calibration_file, + ScoreCalibrationCalculatorOptions* calculator_options) { + calculator_options->set_score_transformation( + ConvertScoreTransformationType(score_transformation)); + calculator_options->set_default_score(default_score); + + if (score_calibration_file.empty()) { + return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument, + "Expected non-empty score calibration file.", + 
MediaPipeTasksStatus::kInvalidArgumentError); + } + std::vector<absl::string_view> lines = + absl::StrSplit(score_calibration_file, '\n'); + for (const auto& line : lines) { + auto* sigmoid = calculator_options->add_sigmoids(); + MP_RETURN_IF_ERROR(FillSigmoidFromLine(line, sigmoid)); + } + + return absl::OkStatus(); +} + +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/calculators/score_calibration_utils.h b/mediapipe/tasks/cc/components/calculators/score_calibration_utils.h new file mode 100644 index 000000000..5c3d446ee --- /dev/null +++ b/mediapipe/tasks/cc/components/calculators/score_calibration_utils.h @@ -0,0 +1,38 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CALCULATORS_SCORE_CALIBRATION_UTILS_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_CALCULATORS_SCORE_CALIBRATION_UTILS_H_ + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" +#include "mediapipe/tasks/metadata/metadata_schema_generated.h" + +namespace mediapipe { +namespace tasks { + +// Populates ScoreCalibrationCalculatorOptions given a TFLite Metadata score +// transformation type, default threshold and score calibration AssociatedFile +// contents, as specified in `TENSOR_AXIS_SCORE_CALIBRATION`: +// https://github.com/google/mediapipe/blob/master/mediapipe/tasks/metadata/metadata_schema.fbs +absl::Status ConfigureScoreCalibration( + tflite::ScoreTransformationType score_transformation, float default_score, + absl::string_view score_calibration_file, + ScoreCalibrationCalculatorOptions* options); + +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CALCULATORS_SCORE_CALIBRATION_UTILS_H_ diff --git a/mediapipe/tasks/cc/components/calculators/score_calibration_utils_test.cc b/mediapipe/tasks/cc/components/calculators/score_calibration_utils_test.cc new file mode 100644 index 000000000..dc7fd90cd --- /dev/null +++ b/mediapipe/tasks/cc/components/calculators/score_calibration_utils_test.cc @@ -0,0 +1,130 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h" + +#include "absl/strings/str_cat.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" +#include "mediapipe/tasks/metadata/metadata_schema_generated.h" + +namespace mediapipe { +namespace tasks { +namespace { + +using ::testing::HasSubstr; + +TEST(ConfigureScoreCalibrationTest, SucceedsWithoutTrailingNewline) { + ScoreCalibrationCalculatorOptions options; + std::string score_calibration_file = + absl::StrCat("\n", "0.1,0.2,0.3\n", "0.4,0.5,0.6,0.7"); + + MP_ASSERT_OK(ConfigureScoreCalibration( + tflite::ScoreTransformationType_IDENTITY, + /*default_score=*/0.5, score_calibration_file, &options)); + + EXPECT_THAT( + options, + EqualsProto(ParseTextProtoOrDie(R"pb( + score_transformation: IDENTITY + default_score: 0.5 + sigmoids {} + sigmoids { scale: 0.1 slope: 0.2 offset: 0.3 } + sigmoids { scale: 0.4 slope: 0.5 offset: 0.6 min_score: 0.7 } + )pb"))); +} + +TEST(ConfigureScoreCalibrationTest, SucceedsWithTrailingNewline) { + ScoreCalibrationCalculatorOptions options; + std::string score_calibration_file = + absl::StrCat("\n", "0.1,0.2,0.3\n", "0.4,0.5,0.6,0.7\n"); + + MP_ASSERT_OK(ConfigureScoreCalibration(tflite::ScoreTransformationType_LOG, + /*default_score=*/0.5, + score_calibration_file, &options)); + + EXPECT_THAT( + options, + EqualsProto(ParseTextProtoOrDie(R"pb( + score_transformation: LOG + default_score: 0.5 + sigmoids {} + sigmoids { scale: 0.1 slope: 0.2 offset: 0.3 } + sigmoids { scale: 0.4 slope: 0.5 offset: 0.6 min_score: 0.7 } + sigmoids {} + )pb"))); +} + +TEST(ConfigureScoreCalibrationTest, FailsWithEmptyFile) { + 
ScoreCalibrationCalculatorOptions options; + + auto status = + ConfigureScoreCalibration(tflite::ScoreTransformationType_LOG, + /*default_score=*/0.5, + /*score_calibration_file=*/"", &options); + + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(status.message(), + HasSubstr("Expected non-empty score calibration file")); +} + +TEST(ConfigureScoreCalibrationTest, FailsWithInvalidNumParameters) { + ScoreCalibrationCalculatorOptions options; + std::string score_calibration_file = absl::StrCat("0.1,0.2,0.3\n", "0.1,0.2"); + + auto status = ConfigureScoreCalibration(tflite::ScoreTransformationType_LOG, + /*default_score=*/0.5, + score_calibration_file, &options); + + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(status.message(), + HasSubstr("Expected 3 or 4 parameters per line")); +} + +TEST(ConfigureScoreCalibrationTest, FailsWithNonParseableParameter) { + ScoreCalibrationCalculatorOptions options; + std::string score_calibration_file = + absl::StrCat("0.1,0.2,0.3\n", "0.1,foo,0.3\n"); + + auto status = ConfigureScoreCalibration(tflite::ScoreTransformationType_LOG, + /*default_score=*/0.5, + score_calibration_file, &options); + + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT( + status.message(), + HasSubstr("Could not parse score calibration parameter as float")); +} + +TEST(ConfigureScoreCalibrationTest, FailsWithNegativeScaleParameter) { + ScoreCalibrationCalculatorOptions options; + std::string score_calibration_file = + absl::StrCat("0.1,0.2,0.3\n", "-0.1,0.2,0.3\n"); + + auto status = ConfigureScoreCalibration(tflite::ScoreTransformationType_LOG, + /*default_score=*/0.5, + score_calibration_file, &options); + + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT( + status.message(), + HasSubstr("The scale parameter of the sigmoids must be positive")); +} + +} // namespace +} // namespace tasks +} // namespace mediapipe diff --git 
a/mediapipe/tasks/cc/core/external_file_handler.cc b/mediapipe/tasks/cc/core/external_file_handler.cc index 7e20d8ef4..8a219bb80 100644 --- a/mediapipe/tasks/cc/core/external_file_handler.cc +++ b/mediapipe/tasks/cc/core/external_file_handler.cc @@ -18,8 +18,18 @@ limitations under the License. #include #include #include + +#ifdef ABSL_HAVE_MMAP #include +#endif + +#ifdef _WIN32 +#include +#include +#include +#else #include +#endif #include #include @@ -44,12 +54,17 @@ using ::absl::StatusCode; // file descriptor correctly, as according to mmap(2), the offset used in mmap // must be a multiple of sysconf(_SC_PAGE_SIZE). int64 GetPageSizeAlignedOffset(int64 offset) { +#ifdef _WIN32 + // mmap is not used on Windows + return -1; +#else int64 aligned_offset = offset; int64 page_size = sysconf(_SC_PAGE_SIZE); if (offset % page_size != 0) { aligned_offset = offset / page_size * page_size; } return aligned_offset; +#endif } } // namespace @@ -69,6 +84,12 @@ ExternalFileHandler::CreateFromExternalFile( } absl::Status ExternalFileHandler::MapExternalFile() { +// TODO: Add Windows support +#ifdef _WIN32 + return CreateStatusWithPayload(StatusCode::kFailedPrecondition, + "File loading is not yet supported on Windows", + MediaPipeTasksStatus::kFileReadError); +#else if (!external_file_.file_content().empty()) { return absl::OkStatus(); } @@ -169,6 +190,7 @@ absl::Status ExternalFileHandler::MapExternalFile() { MediaPipeTasksStatus::kFileMmapError); } return absl::OkStatus(); +#endif } absl::string_view ExternalFileHandler::GetFileContent() { @@ -182,9 +204,11 @@ absl::string_view ExternalFileHandler::GetFileContent() { } ExternalFileHandler::~ExternalFileHandler() { +#ifndef _WIN32 if (buffer_ != MAP_FAILED) { munmap(buffer_, buffer_aligned_size_); } +#endif if (owned_fd_ >= 0) { close(owned_fd_); } diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto 
b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto index 6393c822e..42f2bbc85 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto +++ b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - +// TODO Refactor naming and class structure of hand related Tasks. syntax = "proto2"; package mediapipe.tasks.vision.hand_gesture_recognizer.proto; diff --git a/mediapipe/tasks/cc/vision/hand_landmark/hand_landmark_detector_options.proto b/mediapipe/tasks/cc/vision/hand_landmark/hand_landmark_detector_options.proto index 3de64e593..a2cfc7eaf 100644 --- a/mediapipe/tasks/cc/vision/hand_landmark/hand_landmark_detector_options.proto +++ b/mediapipe/tasks/cc/vision/hand_landmark/hand_landmark_detector_options.proto @@ -24,7 +24,7 @@ message HandLandmarkDetectorOptions { extend mediapipe.CalculatorOptions { optional HandLandmarkDetectorOptions ext = 462713202; } - // Base options for configuring Task library, such as specifying the TfLite + // Base options for configuring MediaPipe Tasks, such as specifying the TfLite // model file with metadata, accelerator options, etc. 
optional core.proto.BaseOptions base_options = 1; diff --git a/mediapipe/tasks/cc/vision/image_classification/image_classifier_options.proto b/mediapipe/tasks/cc/vision/image_classification/image_classifier_options.proto index 1fa221179..21fb3cd8c 100644 --- a/mediapipe/tasks/cc/vision/image_classification/image_classifier_options.proto +++ b/mediapipe/tasks/cc/vision/image_classification/image_classifier_options.proto @@ -25,7 +25,7 @@ message ImageClassifierOptions { extend mediapipe.CalculatorOptions { optional ImageClassifierOptions ext = 456383383; } - // Base options for configuring Task library, such as specifying the TfLite + // Base options for configuring MediaPipe Tasks, such as specifying the TfLite // model file with metadata, accelerator options, etc. optional core.proto.BaseOptions base_options = 1; diff --git a/mediapipe/tasks/cc/vision/segmentation/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD similarity index 81% rename from mediapipe/tasks/cc/vision/segmentation/BUILD rename to mediapipe/tasks/cc/vision/image_segmenter/BUILD index cc4d8236f..cb0482e42 100644 --- a/mediapipe/tasks/cc/vision/segmentation/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD @@ -12,34 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") - package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -mediapipe_proto_library( - name = "image_segmenter_options_proto", - srcs = ["image_segmenter_options.proto"], - deps = [ - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components:segmenter_options_proto", - "//mediapipe/tasks/cc/core/proto:base_options_proto", - ], -) - cc_library( name = "image_segmenter", srcs = ["image_segmenter.cc"], hdrs = ["image_segmenter.h"], deps = [ ":image_segmenter_graph", - ":image_segmenter_options_cc_proto", "//mediapipe/framework/api2:builder", "//mediapipe/framework/formats:image", - "//mediapipe/tasks/cc/core:base_task_api", - "//mediapipe/tasks/cc/core:task_api_factory", + "//mediapipe/tasks/cc/components:segmenter_options_cc_proto", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:running_mode", + "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", "@org_tensorflow//tensorflow/lite/core/api:op_resolver", @@ -51,7 +42,6 @@ cc_library( name = "image_segmenter_graph", srcs = ["image_segmenter_graph.cc"], deps = [ - ":image_segmenter_options_cc_proto", "//mediapipe/calculators/core:merge_to_vector_calculator", "//mediapipe/calculators/image:image_properties_calculator", "//mediapipe/calculators/tensor:image_to_tensor_calculator", @@ -70,6 +60,7 @@ cc_library( "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", "//mediapipe/tasks/cc/metadata:metadata_extractor", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto", 
"//mediapipe/tasks/metadata:metadata_schema_cc", "//mediapipe/util:label_map_cc_proto", "//mediapipe/util:label_map_util", @@ -82,9 +73,9 @@ cc_library( ) cc_library( - name = "custom_op_resolvers", - srcs = ["custom_op_resolvers.cc"], - hdrs = ["custom_op_resolvers.h"], + name = "image_segmenter_op_resolvers", + srcs = ["image_segmenter_op_resolvers.cc"], + hdrs = ["image_segmenter_op_resolvers.h"], deps = [ "//mediapipe/util/tflite/operations:landmarks_to_transform_matrix", "//mediapipe/util/tflite/operations:max_pool_argmax", diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc new file mode 100644 index 000000000..090149d92 --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc @@ -0,0 +1,134 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h" + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/tasks/cc/components/segmenter_options.pb.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/core/running_mode.h" +#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace { + +constexpr char kSegmentationStreamName[] = "segmented_mask_out"; +constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION"; +constexpr char kImageInStreamName[] = "image_in"; +constexpr char kImageOutStreamName[] = "image_out"; +constexpr char kImageTag[] = "IMAGE"; +constexpr char kSubgraphTypeName[] = + "mediapipe.tasks.vision.ImageSegmenterGraph"; + +using ::mediapipe::CalculatorGraphConfig; +using ::mediapipe::Image; +using ImageSegmenterOptionsProto = + image_segmenter::proto::ImageSegmenterOptions; + +// Creates a MediaPipe graph config that only contains a single subgraph node of +// "mediapipe.tasks.vision.ImageSegmenterGraph". 
+CalculatorGraphConfig CreateGraphConfig( + std::unique_ptr options, + bool enable_flow_limiting) { + api2::builder::Graph graph; + auto& task_subgraph = graph.AddNode(kSubgraphTypeName); + task_subgraph.GetOptions().Swap(options.get()); + graph.In(kImageTag).SetName(kImageInStreamName); + task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >> + graph.Out(kGroupedSegmentationTag); + task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> + graph.Out(kImageTag); + if (enable_flow_limiting) { + return tasks::core::AddFlowLimiterCalculator( + graph, task_subgraph, {kImageTag}, kGroupedSegmentationTag); + } + graph.In(kImageTag) >> task_subgraph.In(kImageTag); + return graph.GetConfig(); +} + +// Converts the user-facing ImageSegmenterOptions struct to the internal +// ImageSegmenterOptions proto. +std::unique_ptr ConvertImageSegmenterOptionsToProto( + ImageSegmenterOptions* options) { + auto options_proto = std::make_unique(); + auto base_options_proto = std::make_unique( + tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); + options_proto->mutable_base_options()->Swap(base_options_proto.get()); + options_proto->mutable_base_options()->set_use_stream_mode( + options->running_mode != core::RunningMode::IMAGE); + options_proto->set_display_names_locale(options->display_names_locale); + switch (options->output_type) { + case ImageSegmenterOptions::OutputType::CATEGORY_MASK: + options_proto->mutable_segmenter_options()->set_output_type( + SegmenterOptions::CATEGORY_MASK); + break; + case ImageSegmenterOptions::OutputType::CONFIDENCE_MASK: + options_proto->mutable_segmenter_options()->set_output_type( + SegmenterOptions::CONFIDENCE_MASK); + break; + } + switch (options->activation) { + case ImageSegmenterOptions::Activation::NONE: + options_proto->mutable_segmenter_options()->set_activation( + SegmenterOptions::NONE); + break; + case ImageSegmenterOptions::Activation::SIGMOID: + 
options_proto->mutable_segmenter_options()->set_activation( + SegmenterOptions::SIGMOID); + break; + case ImageSegmenterOptions::Activation::SOFTMAX: + options_proto->mutable_segmenter_options()->set_activation( + SegmenterOptions::SOFTMAX); + break; + } + return options_proto; +} + +} // namespace + +absl::StatusOr> ImageSegmenter::Create( + std::unique_ptr options) { + auto options_proto = ConvertImageSegmenterOptionsToProto(options.get()); + tasks::core::PacketsCallback packets_callback = nullptr; + return core::VisionTaskApiFactory::Create( + CreateGraphConfig( + std::move(options_proto), + options->running_mode == core::RunningMode::LIVE_STREAM), + std::move(options->base_options.op_resolver), options->running_mode, + std::move(packets_callback)); +} + +absl::StatusOr> ImageSegmenter::Segment( + mediapipe::Image image) { + if (image.UsesGpu()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrCat("GPU input images are currently not supported."), + MediaPipeTasksStatus::kRunnerUnexpectedInputError); + } + ASSIGN_OR_RETURN( + auto output_packets, + ProcessImageData({{kImageInStreamName, + mediapipe::MakePacket(std::move(image))}})); + return output_packets[kSegmentationStreamName].Get>(); +} + +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h new file mode 100644 index 000000000..00c63953a --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h @@ -0,0 +1,123 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_H_ +#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_H_ + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/status/statusor.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/tasks/cc/core/base_options.h" +#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h" +#include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/kernels/register.h" + +namespace mediapipe { +namespace tasks { +namespace vision { + +// The options for configuring a mediapipe image segmenter task. +struct ImageSegmenterOptions { + // Base options for configuring MediaPipe Tasks, such as specifying the model + // file with metadata, accelerator options, op resolver, etc. + tasks::core::BaseOptions base_options; + + // The running mode of the task. Default to the image mode. + // Image segmenter has three running modes: + // 1) The image mode for segmenting image on single image inputs. + // 2) The video mode for segmenting image on the decoded frames of a video. + // 3) The live stream mode for segmenting image on the live stream of input + // data, such as from camera. In this mode, the "result_callback" below must + // be specified to receive the segmentation results asynchronously. 
+ core::RunningMode running_mode = core::RunningMode::IMAGE; + + // The locale to use for display names specified through the TFLite Model + // Metadata, if any. Defaults to English. + std::string display_names_locale = "en"; + + // The output type of segmentation results. + enum OutputType { + // Gives a single output mask where each pixel represents the class which + // the pixel in the original image was predicted to belong to. + CATEGORY_MASK = 0, + // Gives a list of output masks where, for each mask, each pixel represents + // the prediction confidence, usually in the [0, 1] range. + CONFIDENCE_MASK = 1, + }; + + OutputType output_type = OutputType::CATEGORY_MASK; + + // The activation function used on the raw segmentation model output. + enum Activation { + NONE = 0, // No activation function is used. + SIGMOID = 1, // Assumes 1-channel input tensor. + SOFTMAX = 2, // Assumes multi-channel input tensor. + }; + + Activation activation = Activation::NONE; + + // The user-defined result callback for processing live stream data. + // The result callback should only be specified when the running mode is set + // to RunningMode::LIVE_STREAM. + std::function>, + const Image&, int64)> + result_callback = nullptr; +}; + +// Performs segmentation on images. +// +// The API expects a TFLite model with mandatory TFLite Model Metadata. +// +// Input tensor: +// (kTfLiteUInt8/kTfLiteFloat32) +// - image input of size `[batch x height x width x channels]`. +// - batch inference is not supported (`batch` is required to be 1). +// - RGB and greyscale inputs are supported (`channels` is required to be +// 1 or 3). +// - if type is kTfLiteFloat32, NormalizationOptions are required to be +// attached to the metadata for input normalization. +// Output tensors: +// (kTfLiteUInt8/kTfLiteFloat32) +// - list of segmented masks. +// - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1. 
+// - if `output_type` is CONFIDENCE_MASK, float32 Image list of size +// `channels`. +// - batch is always 1 +// An example of such model can be found at: +// https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2 +class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { + public: + using BaseVisionTaskApi::BaseVisionTaskApi; + + // Creates an ImageSegmenter from the provided options. A non-default + // OpResolver can be specified in the BaseOptions of ImageSegmenterOptions, + // to support custom Ops of the segmentation model. + static absl::StatusOr> Create( + std::unique_ptr options); + + // Runs the actual segmentation task. + absl::StatusOr> Segment(mediapipe::Image image); +}; + +} // namespace vision +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_H_ diff --git a/mediapipe/tasks/cc/vision/segmentation/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc similarity index 84% rename from mediapipe/tasks/cc/vision/segmentation/image_segmenter_graph.cc rename to mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index b960fd930..d843689e2 100644 --- a/mediapipe/tasks/cc/vision/segmentation/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -33,7 +33,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" -#include "mediapipe/tasks/cc/vision/segmentation/image_segmenter_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" #include "mediapipe/util/label_map.pb.h" #include "mediapipe/util/label_map_util.h" @@ -53,6 +53,7 @@ using ::mediapipe::api2::builder::MultiSource; using ::mediapipe::api2::builder::Source; using ::mediapipe::tasks::SegmenterOptions; using ::mediapipe::tasks::metadata::ModelMetadataExtractor; +using ::mediapipe::tasks::vision::image_segmenter::proto::ImageSegmenterOptions; using ::tflite::Tensor; using ::tflite::TensorMetadata; using LabelItems = mediapipe::proto_ns::Map; @@ -63,6 +64,14 @@ constexpr char kImageTag[] = "IMAGE"; constexpr char kTensorsTag[] = "TENSORS"; constexpr char kOutputSizeTag[] = "OUTPUT_SIZE"; +// Struct holding the different output streams produced by the image segmenter +// subgraph. +struct ImageSegmenterOutputs { + std::vector> segmented_masks; + // The same as the input image, mainly used for live stream mode. + Source image; +}; + } // namespace absl::Status SanityCheckOptions(const ImageSegmenterOptions& options) { @@ -140,6 +149,10 @@ absl::StatusOr GetOutputTensor( // An "mediapipe.tasks.vision.ImageSegmenterGraph" performs semantic // segmentation. +// Two kinds of outputs are provided: SEGMENTATION and GROUPED_SEGMENTATION. +// Users can retrieve segmented mask of only particular category/channel from +// SEGMENTATION, and users can also get all segmented masks from +// GROUPED_SEGMENTATION. // - Accepts CPU input images and outputs segmented masks on CPU. // // Inputs: @@ -147,8 +160,13 @@ absl::StatusOr GetOutputTensor( // Image to perform segmentation on. // // Outputs: -// SEGMENTATION - SEGMENTATION -// Segmented masks. 
+// SEGMENTATION - mediapipe::Image @Multiple +// Segmented masks for individual category. Segmented mask of single +// category can be accessed by index based output stream. +// GROUPED_SEGMENTATION - std::vector +// The output segmented masks grouped in a vector. +// IMAGE - mediapipe::Image +// The image that image segmenter runs on. // // Example: // node { @@ -156,7 +174,8 @@ absl::StatusOr GetOutputTensor( // input_stream: "IMAGE:image" // output_stream: "SEGMENTATION:segmented_masks" // options { -// [mediapipe.tasks.ImageSegmenterOptions.ext] { +// [mediapipe.tasks.vision.image_segmenter.proto.ImageSegmenterOptions.ext] +// { // segmenter_options { // output_type: CONFIDENCE_MASK // activation: SOFTMAX @@ -171,20 +190,22 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { ASSIGN_OR_RETURN(const auto* model_resources, CreateModelResources(sc)); Graph graph; - ASSIGN_OR_RETURN(auto segmentations, + ASSIGN_OR_RETURN(auto output_streams, BuildSegmentationTask( sc->Options(), *model_resources, graph[Input(kImageTag)], graph)); auto& merge_images_to_vector = graph.AddNode("MergeImagesToVectorCalculator"); - for (int i = 0; i < segmentations.size(); ++i) { - segmentations[i] >> merge_images_to_vector[Input::Multiple("")][i]; - segmentations[i] >> graph[Output::Multiple(kSegmentationTag)][i]; + for (int i = 0; i < output_streams.segmented_masks.size(); ++i) { + output_streams.segmented_masks[i] >> + merge_images_to_vector[Input::Multiple("")][i]; + output_streams.segmented_masks[i] >> + graph[Output::Multiple(kSegmentationTag)][i]; } merge_images_to_vector.Out("") >> graph[Output>(kGroupedSegmentationTag)]; - + output_streams.image >> graph[Output(kImageTag)]; return graph.GetConfig(); } @@ -193,12 +214,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { // builder::Graph instance. The segmentation pipeline takes images // (mediapipe::Image) as the input and returns segmented image mask as output. 
// - // task_options: the mediapipe tasks ImageSegmenterOptions. + // task_options: the mediapipe tasks ImageSegmenterOptions proto. // model_resources: the ModelSources object initialized from a segmentation // model file with model metadata. // image_in: (mediapipe::Image) stream to run segmentation on. // graph: the mediapipe builder::Graph instance to be updated. - absl::StatusOr>> BuildSegmentationTask( + absl::StatusOr BuildSegmentationTask( const ImageSegmenterOptions& task_options, const core::ModelResources& model_resources, Source image_in, Graph& graph) { @@ -246,7 +267,10 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { tensor_to_images[Output::Multiple(kSegmentationTag)][i])); } } - return segmented_masks; + return {{ + .segmented_masks = segmented_masks, + .image = preprocessing[Output(kImageTag)], + }}; } }; diff --git a/mediapipe/tasks/cc/vision/segmentation/custom_op_resolvers.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.cc similarity index 96% rename from mediapipe/tasks/cc/vision/segmentation/custom_op_resolvers.cc rename to mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.cc index b24b426ad..cd3b5690f 100644 --- a/mediapipe/tasks/cc/vision/segmentation/custom_op_resolvers.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "mediapipe/tasks/cc/vision/segmentation/custom_op_resolvers.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h" #include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h" #include "mediapipe/util/tflite/operations/max_pool_argmax.h" diff --git a/mediapipe/tasks/cc/vision/segmentation/custom_op_resolvers.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h similarity index 81% rename from mediapipe/tasks/cc/vision/segmentation/custom_op_resolvers.h rename to mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h index 2b185d792..a0538a674 100644 --- a/mediapipe/tasks/cc/vision/segmentation/custom_op_resolvers.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_VISION_SEGMENTATION_CUSTOM_OP_RESOLVERS_H_ -#define MEDIAPIPE_TASKS_CC_VISION_SEGMENTATION_CUSTOM_OP_RESOLVERS_H_ +#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_ +#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_ #include "tensorflow/lite/kernels/register.h" @@ -34,4 +34,4 @@ class SelfieSegmentationModelOpResolver } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_VISION_SEGMENTATION_CUSTOM_OP_RESOLVERS_H_ +#endif // MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_ diff --git a/mediapipe/tasks/cc/vision/segmentation/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc similarity index 80% rename from mediapipe/tasks/cc/vision/segmentation/image_segmenter_test.cc rename to 
mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc index f1e8ee4f7..f43d28fca 100644 --- a/mediapipe/tasks/cc/vision/segmentation/image_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/vision/segmentation/image_segmenter.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h" #include #include @@ -32,8 +32,8 @@ limitations under the License. #include "mediapipe/tasks/cc/components/segmenter_options.pb.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" -#include "mediapipe/tasks/cc/vision/segmentation/custom_op_resolvers.h" -#include "mediapipe/tasks/cc/vision/segmentation/image_segmenter_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" #include "tensorflow/lite/kernels/builtin_op_kernels.h" @@ -46,11 +46,8 @@ namespace { using ::mediapipe::Image; using ::mediapipe::file::JoinPath; -using ::mediapipe::tasks::ImageSegmenterOptions; -using ::mediapipe::tasks::SegmenterOptions; using ::testing::HasSubstr; using ::testing::Optional; -using ::tflite::ops::builtin::BuiltinOpResolver; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; constexpr char kDeeplabV3WithMetadata[] = "deeplabv3.tflite"; @@ -167,19 +164,19 @@ class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver { TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) { auto options = std::make_unique(); - 
options->mutable_base_options()->mutable_model_file()->set_file_name( - JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata)); - MP_ASSERT_OK(ImageSegmenter::Create(std::move(options), - absl::make_unique())); + options->base_options.model_file_name = + JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); + options->base_options.op_resolver = absl::make_unique(); + MP_ASSERT_OK(ImageSegmenter::Create(std::move(options))); } TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) { auto options = std::make_unique(); - options->mutable_base_options()->mutable_model_file()->set_file_name( - JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata)); - - auto segmenter_or = ImageSegmenter::Create( - std::move(options), absl::make_unique()); + options->base_options.model_file_name = + JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); + options->base_options.op_resolver = + absl::make_unique(); + auto segmenter_or = ImageSegmenter::Create(std::move(options)); // TODO: Make MediaPipe InferenceCalculator report the detailed // interpreter errors (e.g., "Encountered unresolved custom op"). 
EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInternal); @@ -202,24 +199,6 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) { MediaPipeTasksStatus::kRunnerInitializationError)))); } -TEST_F(CreateFromOptionsTest, FailsWithUnspecifiedOutputType) { - auto options = std::make_unique(); - options->mutable_base_options()->mutable_model_file()->set_file_name( - JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata)); - options->mutable_segmenter_options()->set_output_type( - SegmenterOptions::UNSPECIFIED); - - auto segmenter_or = ImageSegmenter::Create( - std::move(options), absl::make_unique()); - - EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInvalidArgument); - EXPECT_THAT(segmenter_or.status().message(), - HasSubstr("`output_type` must not be UNSPECIFIED")); - EXPECT_THAT(segmenter_or.status().GetPayload(kMediaPipeTasksPayload), - Optional(absl::Cord(absl::StrCat( - MediaPipeTasksStatus::kRunnerInitializationError)))); -} - class SegmentationTest : public tflite_shims::testing::Test {}; TEST_F(SegmentationTest, SucceedsWithCategoryMask) { @@ -228,10 +207,10 @@ TEST_F(SegmentationTest, SucceedsWithCategoryMask) { DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "segmentation_input_rotation0.jpg"))); auto options = std::make_unique(); - options->mutable_segmenter_options()->set_output_type( - SegmenterOptions::CATEGORY_MASK); - options->mutable_base_options()->mutable_model_file()->set_file_name( - JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata)); + options->base_options.model_file_name = + JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); + options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto category_masks, segmenter->Segment(image)); @@ -253,12 +232,11 @@ TEST_F(SegmentationTest, SucceedsWithConfidenceMask) { Image image, DecodeImageFromFile(JoinPath("./", 
kTestDataDirectory, "cat.jpg"))); auto options = std::make_unique(); - options->mutable_segmenter_options()->set_output_type( - SegmenterOptions::CONFIDENCE_MASK); - options->mutable_segmenter_options()->set_activation( - SegmenterOptions::SOFTMAX); - options->mutable_base_options()->mutable_model_file()->set_file_name( - JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata)); + options->base_options.model_file_name = + JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); + options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; + options->activation = ImageSegmenterOptions::Activation::SOFTMAX; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto results, segmenter->Segment(image)); @@ -281,17 +259,15 @@ TEST_F(SegmentationTest, SucceedsSelfie128x128Segmentation) { Image image = GetSRGBImage(JoinPath("./", kTestDataDirectory, "mozart_square.jpg")); auto options = std::make_unique(); - options->mutable_segmenter_options()->set_output_type( - SegmenterOptions::CONFIDENCE_MASK); - options->mutable_segmenter_options()->set_activation( - SegmenterOptions::SOFTMAX); - options->mutable_base_options()->mutable_model_file()->set_file_name( - JoinPath("./", kTestDataDirectory, kSelfie128x128WithMetadata)); - MP_ASSERT_OK_AND_ASSIGN( - std::unique_ptr segmenter, - ImageSegmenter::Create( - std::move(options), - absl::make_unique())); + options->base_options.model_file_name = + JoinPath("./", kTestDataDirectory, kSelfie128x128WithMetadata); + options->base_options.op_resolver = + absl::make_unique(); + options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; + options->activation = ImageSegmenterOptions::Activation::SOFTMAX; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, + ImageSegmenter::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); EXPECT_EQ(confidence_masks.size(), 2); @@ -313,15 +289,14 
@@ TEST_F(SegmentationTest, SucceedsSelfie144x256Segmentations) { Image image = GetSRGBImage(JoinPath("./", kTestDataDirectory, "mozart_square.jpg")); auto options = std::make_unique(); - options->mutable_segmenter_options()->set_output_type( - SegmenterOptions::CONFIDENCE_MASK); - options->mutable_base_options()->mutable_model_file()->set_file_name( - JoinPath("./", kTestDataDirectory, kSelfie144x256WithMetadata)); - MP_ASSERT_OK_AND_ASSIGN( - std::unique_ptr segmenter, - ImageSegmenter::Create( - std::move(options), - absl::make_unique())); + options->base_options.model_file_name = + JoinPath("./", kTestDataDirectory, kSelfie144x256WithMetadata); + options->base_options.op_resolver = + absl::make_unique(); + options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; + options->activation = ImageSegmenterOptions::Activation::NONE; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, + ImageSegmenter::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); EXPECT_EQ(confidence_masks.size(), 1); diff --git a/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD new file mode 100644 index 000000000..b9b8ea436 --- /dev/null +++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD @@ -0,0 +1,30 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_proto_library( + name = "image_segmenter_options_proto", + srcs = ["image_segmenter_options.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/components:segmenter_options_proto", + "//mediapipe/tasks/cc/core/proto:base_options_proto", + ], +) diff --git a/mediapipe/tasks/cc/vision/segmentation/image_segmenter_options.proto b/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.proto similarity index 91% rename from mediapipe/tasks/cc/vision/segmentation/image_segmenter_options.proto rename to mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.proto index ab8ff7c83..fcb2914cf 100644 --- a/mediapipe/tasks/cc/vision/segmentation/image_segmenter_options.proto +++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.proto @@ -15,7 +15,7 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks; +package mediapipe.tasks.vision.image_segmenter.proto; import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/components/segmenter_options.proto"; @@ -25,7 +25,7 @@ message ImageSegmenterOptions { extend mediapipe.CalculatorOptions { optional ImageSegmenterOptions ext = 458105758; } - // Base options for configuring Task library, such as specifying the TfLite + // Base options for configuring MediaPipe Tasks, such as specifying the TfLite // model file with metadata, accelerator options, etc. 
optional core.proto.BaseOptions base_options = 1; diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.h b/mediapipe/tasks/cc/vision/object_detector/object_detector.h index 6f23e9b52..e98013223 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector.h +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector.h @@ -36,10 +36,19 @@ namespace vision { // The options for configuring a mediapipe object detector task. struct ObjectDetectorOptions { - // Base options for configuring Task library, such as specifying the TfLite + // Base options for configuring MediaPipe Tasks, such as specifying the TfLite // model file with metadata, accelerator options, op resolver, etc. tasks::core::BaseOptions base_options; + // The running mode of the task. Default to the image mode. + // Object detector has three running modes: + // 1) The image mode for detecting objects on single image inputs. + // 2) The video mode for detecting objects on the decoded frames of a video. + // 3) The live stream mode for detecting objects on the live stream of input + // data, such as from camera. In this mode, the "result_callback" below must + // be specified to receive the detection results asynchronously. + core::RunningMode running_mode = core::RunningMode::IMAGE; + // The locale to use for display names specified through the TFLite Model // Metadata, if any. Defaults to English. std::string display_names_locale = "en"; @@ -65,15 +74,6 @@ struct ObjectDetectorOptions { // category names are ignored. Mutually exclusive with category_allowlist. std::vector category_denylist = {}; - // The running mode of the task. Default to the image mode. - // Object detector has three running modes: - // 1) The image mode for detecting objects on single image inputs. - // 2) The video mode for detecting objects on the decoded frames of a video. - // 3) The live stream mode for detecting objects on the live stream of input - // data, such as from camera. 
In this mode, the "result_callback" below must - // be specified to receive the detection results asynchronously. - core::RunningMode running_mode = core::RunningMode::IMAGE; - // The user-defined result callback for processing live stream data. // The result callback should only be specified when the running mode is set // to RunningMode::LIVE_STREAM. diff --git a/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto b/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto index 5e2955a9f..37edab1d9 100644 --- a/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto +++ b/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto @@ -27,7 +27,7 @@ message ObjectDetectorOptions { extend mediapipe.CalculatorOptions { optional ObjectDetectorOptions ext = 443442058; } - // Base options for configuring Task library, such as specifying the TfLite + // Base options for configuring MediaPipe Tasks, such as specifying the TfLite // model file with metadata, accelerator options, etc. optional core.proto.BaseOptions base_options = 1; diff --git a/mediapipe/tasks/cc/vision/segmentation/image_segmenter.cc b/mediapipe/tasks/cc/vision/segmentation/image_segmenter.cc deleted file mode 100644 index efed5685f..000000000 --- a/mediapipe/tasks/cc/vision/segmentation/image_segmenter.cc +++ /dev/null @@ -1,75 +0,0 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "mediapipe/tasks/cc/vision/segmentation/image_segmenter.h" - -#include "mediapipe/framework/api2/builder.h" -#include "mediapipe/tasks/cc/core/task_api_factory.h" - -namespace mediapipe { -namespace tasks { -namespace vision { -namespace { - -constexpr char kSegmentationStreamName[] = "segmented_mask_out"; -constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION"; -constexpr char kImageStreamName[] = "image_in"; -constexpr char kImageTag[] = "IMAGE"; -constexpr char kSubgraphTypeName[] = - "mediapipe.tasks.vision.ImageSegmenterGraph"; - -using ::mediapipe::CalculatorGraphConfig; -using ::mediapipe::Image; - -// Creates a MediaPipe graph config that only contains a single subgraph node of -// "mediapipe.tasks.vision.SegmenterGraph". -CalculatorGraphConfig CreateGraphConfig( - std::unique_ptr options) { - api2::builder::Graph graph; - auto& subgraph = graph.AddNode(kSubgraphTypeName); - subgraph.GetOptions().Swap(options.get()); - graph.In(kImageTag).SetName(kImageStreamName) >> subgraph.In(kImageTag); - subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >> - graph.Out(kGroupedSegmentationTag); - return graph.GetConfig(); -} - -} // namespace - -absl::StatusOr> ImageSegmenter::Create( - std::unique_ptr options, - std::unique_ptr resolver) { - return core::TaskApiFactory::Create( - CreateGraphConfig(std::move(options)), std::move(resolver)); -} - -absl::StatusOr> ImageSegmenter::Segment( - mediapipe::Image image) { - if (image.UsesGpu()) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - absl::StrCat("GPU input images are currently not supported."), - MediaPipeTasksStatus::kRunnerUnexpectedInputError); - } - ASSIGN_OR_RETURN( - auto output_packets, - runner_->Process({{kImageStreamName, - 
mediapipe::MakePacket(std::move(image))}})); - return output_packets[kSegmentationStreamName].Get>(); -} - -} // namespace vision -} // namespace tasks -} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/segmentation/image_segmenter.h b/mediapipe/tasks/cc/vision/segmentation/image_segmenter.h deleted file mode 100644 index 58da9feaf..000000000 --- a/mediapipe/tasks/cc/vision/segmentation/image_segmenter.h +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef MEDIAPIPE_TASKS_CC_VISION_SEGMENTATION_IMAGE_SEGMENTER_H_ -#define MEDIAPIPE_TASKS_CC_VISION_SEGMENTATION_IMAGE_SEGMENTER_H_ - -#include -#include -#include - -#include "absl/memory/memory.h" -#include "absl/status/statusor.h" -#include "mediapipe/framework/formats/image.h" -#include "mediapipe/tasks/cc/core/base_task_api.h" -#include "mediapipe/tasks/cc/vision/segmentation/image_segmenter_options.pb.h" -#include "tensorflow/lite/core/api/op_resolver.h" -#include "tensorflow/lite/kernels/register.h" - -namespace mediapipe { -namespace tasks { -namespace vision { - -// Performs segmentation on images. -// -// The API expects a TFLite model with mandatory TFLite Model Metadata. -// -// Input tensor: -// (kTfLiteUInt8/kTfLiteFloat32) -// - image input of size `[batch x height x width x channels]`. 
-// - batch inference is not supported (`batch` is required to be 1). -// - RGB and greyscale inputs are supported (`channels` is required to be -// 1 or 3). -// - if type is kTfLiteFloat32, NormalizationOptions are required to be -// attached to the metadata for input normalization. -// Output tensors: -// (kTfLiteUInt8/kTfLiteFloat32) -// - list of segmented masks. -// - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1. -// - if `output_type` is CONFIDENCE_MASK, float32 Image list of size -// `cahnnels`. -// - batch is always 1 -// An example of such model can be found at: -// https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2 -class ImageSegmenter : core::BaseTaskApi { - public: - using BaseTaskApi::BaseTaskApi; - - // Creates a Segmenter from the provided options. A non-default - // OpResolver can be specified in order to support custom Ops or specify a - // subset of built-in Ops. - static absl::StatusOr> Create( - std::unique_ptr options, - std::unique_ptr resolver = - absl::make_unique()); - - // Runs the actual segmentation task. - absl::StatusOr> Segment(mediapipe::Image image); -}; - -} // namespace vision -} // namespace tasks -} // namespace mediapipe - -#endif // MEDIAPIPE_TASKS_CC_VISION_SEGMENTATION_IMAGE_SEGMENTER_H_ diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index 2d13eab9c..b52604c2b 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -80,6 +80,7 @@ filegroup( ], ) +# TODO Create individual filegroup for models required for each Tasks. filegroup( name = "test_models", srcs = [