Project import generated by Copybara.
GitOrigin-RevId: ca9878d6e9c5beb87512e7536b200c55d150ede8
This commit is contained in:
parent
ebec590cfe
commit
06c30f1931
|
@ -51,8 +51,7 @@ RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-8 100 --slave /u
|
||||||
RUN pip3 install --upgrade setuptools
|
RUN pip3 install --upgrade setuptools
|
||||||
RUN pip3 install wheel
|
RUN pip3 install wheel
|
||||||
RUN pip3 install future
|
RUN pip3 install future
|
||||||
RUN pip3 install absl-py
|
RUN pip3 install absl-py numpy opencv-contrib-python protobuf==3.20.1
|
||||||
RUN pip3 install numpy
|
|
||||||
RUN pip3 install six==1.14.0
|
RUN pip3 install six==1.14.0
|
||||||
RUN pip3 install tensorflow==2.2.0
|
RUN pip3 install tensorflow==2.2.0
|
||||||
RUN pip3 install tf_slim
|
RUN pip3 install tf_slim
|
||||||
|
|
Binary file not shown.
Before Width: | Height: | Size: 1.2 KiB After Width: | Height: | Size: 2.6 KiB |
Binary file not shown.
After Width: | Height: | Size: 1.2 KiB |
|
@ -84,6 +84,7 @@
|
||||||
{
|
{
|
||||||
"idiom" : "ipad",
|
"idiom" : "ipad",
|
||||||
"size" : "76x76",
|
"size" : "76x76",
|
||||||
|
"filename" : "76_c_Ipad_2x.png",
|
||||||
"scale" : "2x"
|
"scale" : "2x"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
|
@ -25,7 +25,7 @@ message AudioClassifierOptions {
|
||||||
extend mediapipe.CalculatorOptions {
|
extend mediapipe.CalculatorOptions {
|
||||||
optional AudioClassifierOptions ext = 451755788;
|
optional AudioClassifierOptions ext = 451755788;
|
||||||
}
|
}
|
||||||
// Base options for configuring Task library, such as specifying the TfLite
|
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
|
||||||
// model file with metadata, accelerator options, etc.
|
// model file with metadata, accelerator options, etc.
|
||||||
optional core.proto.BaseOptions base_options = 1;
|
optional core.proto.BaseOptions base_options = 1;
|
||||||
|
|
||||||
|
|
|
@ -43,3 +43,73 @@ cc_library(
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mediapipe_proto_library(
|
||||||
|
name = "score_calibration_calculator_proto",
|
||||||
|
srcs = ["score_calibration_calculator.proto"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
|
"//mediapipe/framework:calculator_proto",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "score_calibration_calculator",
|
||||||
|
srcs = ["score_calibration_calculator.cc"],
|
||||||
|
deps = [
|
||||||
|
":score_calibration_calculator_cc_proto",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework/api2:node",
|
||||||
|
"//mediapipe/framework/api2:port",
|
||||||
|
"//mediapipe/framework/formats:tensor",
|
||||||
|
"//mediapipe/framework/port:ret_check",
|
||||||
|
"//mediapipe/tasks/cc:common",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
"@com_google_absl//absl/status:statusor",
|
||||||
|
"@com_google_absl//absl/strings:str_format",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "score_calibration_calculator_test",
|
||||||
|
srcs = ["score_calibration_calculator_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":score_calibration_calculator",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:calculator_runner",
|
||||||
|
"//mediapipe/framework/formats:tensor",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
"@com_google_absl//absl/strings:str_format",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "score_calibration_utils",
|
||||||
|
srcs = ["score_calibration_utils.cc"],
|
||||||
|
hdrs = ["score_calibration_utils.h"],
|
||||||
|
deps = [
|
||||||
|
":score_calibration_calculator_cc_proto",
|
||||||
|
"//mediapipe/framework/port:status",
|
||||||
|
"//mediapipe/tasks/cc:common",
|
||||||
|
"//mediapipe/tasks/metadata:metadata_schema_cc",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_absl//absl/strings:str_format",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "score_calibration_utils_test",
|
||||||
|
srcs = ["score_calibration_utils_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":score_calibration_calculator_cc_proto",
|
||||||
|
":score_calibration_utils",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
|
"//mediapipe/tasks/metadata:metadata_schema_cc",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -0,0 +1,259 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/status/statusor.h"
|
||||||
|
#include "absl/strings/str_format.h"
|
||||||
|
#include "mediapipe/framework/api2/node.h"
|
||||||
|
#include "mediapipe/framework/api2/port.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace api2 {
|
||||||
|
|
||||||
|
using ::absl::StatusCode;
|
||||||
|
using ::mediapipe::tasks::CreateStatusWithPayload;
|
||||||
|
using ::mediapipe::tasks::MediaPipeTasksStatus;
|
||||||
|
using ::mediapipe::tasks::ScoreCalibrationCalculatorOptions;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Used to prevent log(<=0.0) in ClampedLog() calls.
|
||||||
|
constexpr float kLogScoreMinimum = 1e-16;
|
||||||
|
|
||||||
|
// Returns the following, depending on x:
|
||||||
|
// x => threshold: log(x)
|
||||||
|
// x < threshold: 2 * log(thresh) - log(2 * thresh - x)
|
||||||
|
// This form (a) is anti-symmetric about the threshold and (b) has continuous
|
||||||
|
// value and first derivative. This is done to prevent taking the log of values
|
||||||
|
// close to 0 which can lead to floating point errors and is better than simple
|
||||||
|
// clamping since it preserves order for scores less than the threshold.
|
||||||
|
float ClampedLog(float x, float threshold) {
|
||||||
|
if (x < threshold) {
|
||||||
|
return 2.0 * std::log(static_cast<double>(threshold)) -
|
||||||
|
log(2.0 * threshold - x);
|
||||||
|
}
|
||||||
|
return std::log(static_cast<double>(x));
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Applies score calibration to a tensor of score predictions, typically applied
|
||||||
|
// to the output of a classification or object detection model.
|
||||||
|
//
|
||||||
|
// See corresponding options for more details on the score calibration
|
||||||
|
// parameters and formula.
|
||||||
|
//
|
||||||
|
// Inputs:
|
||||||
|
// SCORES - std::vector<Tensor>
|
||||||
|
// A vector containing a single Tensor `x` of type kFloat32, representing
|
||||||
|
// the scores to calibrate. By default (i.e. if INDICES is not connected),
|
||||||
|
// x[i] will be calibrated using the sigmoid provided at index i in the
|
||||||
|
// options.
|
||||||
|
// INDICES - std::vector<Tensor> @Optional
|
||||||
|
// An optional vector containing a single Tensor `y` of type kFloat32 and
|
||||||
|
// same size as `x`. If provided, x[i] will be calibrated using the sigmoid
|
||||||
|
// provided at index y[i] (casted as an integer) in the options. `x` and `y`
|
||||||
|
// must contain the same number of elements. Typically used for object
|
||||||
|
// detection models.
|
||||||
|
//
|
||||||
|
// Outputs:
|
||||||
|
// CALIBRATED_SCORES - std::vector<Tensor>
|
||||||
|
// A vector containing a single Tensor of type kFloat32 and of the same size
|
||||||
|
// as the input tensors. Contains the output calibrated scores.
|
||||||
|
class ScoreCalibrationCalculator : public Node {
|
||||||
|
public:
|
||||||
|
static constexpr Input<std::vector<Tensor>> kScoresIn{"SCORES"};
|
||||||
|
static constexpr Input<std::vector<Tensor>>::Optional kIndicesIn{"INDICES"};
|
||||||
|
static constexpr Output<std::vector<Tensor>> kScoresOut{"CALIBRATED_SCORES"};
|
||||||
|
MEDIAPIPE_NODE_CONTRACT(kScoresIn, kIndicesIn, kScoresOut);
|
||||||
|
|
||||||
|
absl::Status Open(CalculatorContext* cc) override;
|
||||||
|
absl::Status Process(CalculatorContext* cc) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
ScoreCalibrationCalculatorOptions options_;
|
||||||
|
std::function<float(float)> score_transformation_;
|
||||||
|
|
||||||
|
// Computes the calibrated score for the provided index. Does not check for
|
||||||
|
// out-of-bounds index.
|
||||||
|
float ComputeCalibratedScore(int index, float score);
|
||||||
|
// Same as above, but does check for out-of-bounds index.
|
||||||
|
absl::StatusOr<float> SafeComputeCalibratedScore(int index, float score);
|
||||||
|
};
|
||||||
|
|
||||||
|
absl::Status ScoreCalibrationCalculator::Open(CalculatorContext* cc) {
|
||||||
|
options_ = cc->Options<ScoreCalibrationCalculatorOptions>();
|
||||||
|
// Sanity checks.
|
||||||
|
if (options_.sigmoids_size() == 0) {
|
||||||
|
return CreateStatusWithPayload(StatusCode::kInvalidArgument,
|
||||||
|
"Expected at least one sigmoid, found none.",
|
||||||
|
MediaPipeTasksStatus::kInvalidArgumentError);
|
||||||
|
}
|
||||||
|
for (const auto& sigmoid : options_.sigmoids()) {
|
||||||
|
if (sigmoid.has_scale() && sigmoid.scale() < 0.0) {
|
||||||
|
return CreateStatusWithPayload(
|
||||||
|
StatusCode::kInvalidArgument,
|
||||||
|
absl::StrFormat("The scale parameter of the sigmoids must be "
|
||||||
|
"positive, found %f.",
|
||||||
|
sigmoid.scale()),
|
||||||
|
MediaPipeTasksStatus::kInvalidArgumentError);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Set score transformation function once and for all.
|
||||||
|
switch (options_.score_transformation()) {
|
||||||
|
case tasks::ScoreCalibrationCalculatorOptions::IDENTITY:
|
||||||
|
score_transformation_ = [](float x) { return x; };
|
||||||
|
break;
|
||||||
|
case tasks::ScoreCalibrationCalculatorOptions::LOG:
|
||||||
|
score_transformation_ = [](float x) {
|
||||||
|
return ClampedLog(x, kLogScoreMinimum);
|
||||||
|
};
|
||||||
|
break;
|
||||||
|
case tasks::ScoreCalibrationCalculatorOptions::INVERSE_LOGISTIC:
|
||||||
|
score_transformation_ = [](float x) {
|
||||||
|
return (ClampedLog(x, kLogScoreMinimum) -
|
||||||
|
ClampedLog(1.0 - x, kLogScoreMinimum));
|
||||||
|
};
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return CreateStatusWithPayload(
|
||||||
|
StatusCode::kInvalidArgument,
|
||||||
|
absl::StrFormat(
|
||||||
|
"Unsupported ScoreTransformation type: %s",
|
||||||
|
ScoreCalibrationCalculatorOptions::ScoreTransformation_Name(
|
||||||
|
options_.score_transformation())),
|
||||||
|
MediaPipeTasksStatus::kInvalidArgumentError);
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status ScoreCalibrationCalculator::Process(CalculatorContext* cc) {
|
||||||
|
RET_CHECK_EQ(kScoresIn(cc)->size(), 1);
|
||||||
|
const auto& scores = (*kScoresIn(cc))[0];
|
||||||
|
RET_CHECK(scores.element_type() == Tensor::ElementType::kFloat32);
|
||||||
|
auto scores_view = scores.GetCpuReadView();
|
||||||
|
const float* raw_scores = scores_view.buffer<float>();
|
||||||
|
int num_scores = scores.shape().num_elements();
|
||||||
|
|
||||||
|
auto output_tensors = std::make_unique<std::vector<Tensor>>();
|
||||||
|
output_tensors->reserve(1);
|
||||||
|
output_tensors->emplace_back(scores.element_type(), scores.shape());
|
||||||
|
auto calibrated_scores = &output_tensors->back();
|
||||||
|
auto calibrated_scores_view = calibrated_scores->GetCpuWriteView();
|
||||||
|
float* raw_calibrated_scores = calibrated_scores_view.buffer<float>();
|
||||||
|
|
||||||
|
if (kIndicesIn(cc).IsConnected()) {
|
||||||
|
RET_CHECK_EQ(kIndicesIn(cc)->size(), 1);
|
||||||
|
const auto& indices = (*kIndicesIn(cc))[0];
|
||||||
|
RET_CHECK(indices.element_type() == Tensor::ElementType::kFloat32);
|
||||||
|
if (num_scores != indices.shape().num_elements()) {
|
||||||
|
return CreateStatusWithPayload(
|
||||||
|
StatusCode::kInvalidArgument,
|
||||||
|
absl::StrFormat("Mismatch between number of elements in the input "
|
||||||
|
"scores tensor (%d) and indices tensor (%d).",
|
||||||
|
num_scores, indices.shape().num_elements()),
|
||||||
|
MediaPipeTasksStatus::kMetadataInconsistencyError);
|
||||||
|
}
|
||||||
|
auto indices_view = indices.GetCpuReadView();
|
||||||
|
const float* raw_indices = indices_view.buffer<float>();
|
||||||
|
for (int i = 0; i < num_scores; ++i) {
|
||||||
|
// Use the "safe" flavor as we need to check that the externally provided
|
||||||
|
// indices are not out-of-bounds.
|
||||||
|
ASSIGN_OR_RETURN(raw_calibrated_scores[i],
|
||||||
|
SafeComputeCalibratedScore(
|
||||||
|
static_cast<int>(raw_indices[i]), raw_scores[i]));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (num_scores != options_.sigmoids_size()) {
|
||||||
|
return CreateStatusWithPayload(
|
||||||
|
StatusCode::kInvalidArgument,
|
||||||
|
absl::StrFormat("Mismatch between number of sigmoids (%d) and number "
|
||||||
|
"of elements in the input scores tensor (%d).",
|
||||||
|
options_.sigmoids_size(), num_scores),
|
||||||
|
MediaPipeTasksStatus::kMetadataInconsistencyError);
|
||||||
|
}
|
||||||
|
for (int i = 0; i < num_scores; ++i) {
|
||||||
|
// Use the "unsafe" flavor as we have already checked for out-of-bounds
|
||||||
|
// issues.
|
||||||
|
raw_calibrated_scores[i] = ComputeCalibratedScore(i, raw_scores[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
kScoresOut(cc).Send(std::move(output_tensors));
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
float ScoreCalibrationCalculator::ComputeCalibratedScore(int index,
|
||||||
|
float score) {
|
||||||
|
const auto& sigmoid = options_.sigmoids(index);
|
||||||
|
|
||||||
|
bool is_empty =
|
||||||
|
!sigmoid.has_scale() || !sigmoid.has_offset() || !sigmoid.has_slope();
|
||||||
|
bool is_below_min_score =
|
||||||
|
sigmoid.has_min_score() && score < sigmoid.min_score();
|
||||||
|
if (is_empty || is_below_min_score) {
|
||||||
|
return options_.default_score();
|
||||||
|
}
|
||||||
|
|
||||||
|
float transformed_score = score_transformation_(score);
|
||||||
|
float scale_shifted_score =
|
||||||
|
transformed_score * sigmoid.slope() + sigmoid.offset();
|
||||||
|
// For numerical stability use 1 / (1+exp(-x)) when scale_shifted_score >= 0
|
||||||
|
// and exp(x) / (1+exp(x)) when scale_shifted_score < 0.
|
||||||
|
float calibrated_score;
|
||||||
|
if (scale_shifted_score >= 0.0) {
|
||||||
|
calibrated_score =
|
||||||
|
sigmoid.scale() /
|
||||||
|
(1.0 + std::exp(static_cast<double>(-scale_shifted_score)));
|
||||||
|
} else {
|
||||||
|
float score_exp = std::exp(static_cast<double>(scale_shifted_score));
|
||||||
|
calibrated_score = sigmoid.scale() * score_exp / (1.0 + score_exp);
|
||||||
|
}
|
||||||
|
// Scale is non-negative (checked in SigmoidFromLabelAndLine),
|
||||||
|
// thus calibrated_score should be in the range of [0, scale]. However, due to
|
||||||
|
// numberical stability issue, it may fall out of the boundary. Cap the value
|
||||||
|
// to [0, scale] instead.
|
||||||
|
return std::max(std::min(calibrated_score, sigmoid.scale()), 0.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::StatusOr<float> ScoreCalibrationCalculator::SafeComputeCalibratedScore(
|
||||||
|
int index, float score) {
|
||||||
|
if (index < 0) {
|
||||||
|
return CreateStatusWithPayload(
|
||||||
|
StatusCode::kInvalidArgument,
|
||||||
|
absl::StrFormat("Expected positive indices, found %d.", index),
|
||||||
|
MediaPipeTasksStatus::kInvalidArgumentError);
|
||||||
|
}
|
||||||
|
if (index > options_.sigmoids_size()) {
|
||||||
|
return CreateStatusWithPayload(
|
||||||
|
StatusCode::kInvalidArgument,
|
||||||
|
absl::StrFormat("Unable to get score calibration parameters for index "
|
||||||
|
"%d : only %d sigmoids were provided.",
|
||||||
|
index, options_.sigmoids_size()),
|
||||||
|
MediaPipeTasksStatus::kMetadataInconsistencyError);
|
||||||
|
}
|
||||||
|
return ComputeCalibratedScore(index, score);
|
||||||
|
}
|
||||||
|
|
||||||
|
MEDIAPIPE_REGISTER_NODE(ScoreCalibrationCalculator);
|
||||||
|
|
||||||
|
} // namespace api2
|
||||||
|
} // namespace mediapipe
|
|
@ -0,0 +1,67 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
syntax = "proto2";
|
||||||
|
|
||||||
|
package mediapipe.tasks;
|
||||||
|
|
||||||
|
import "mediapipe/framework/calculator.proto";
|
||||||
|
|
||||||
|
message ScoreCalibrationCalculatorOptions {
|
||||||
|
extend mediapipe.CalculatorOptions {
|
||||||
|
optional ScoreCalibrationCalculatorOptions ext = 470204318;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Score calibration parameters for one individual category. The formula used
|
||||||
|
// to transform the uncalibrated score `x` is:
|
||||||
|
// * `f(x) = scale / (1 + e^-(slope * g(x) + offset))` if `x > min_score` or
|
||||||
|
// if no min_score has been specified,
|
||||||
|
// * `f(x) = default_score` otherwise or if no scale, slope or offset have
|
||||||
|
// been specified.
|
||||||
|
//
|
||||||
|
// Where:
|
||||||
|
// * scale must be positive,
|
||||||
|
// * g(x) is a global (i.e. category independent) transform specified using
|
||||||
|
// the score_transformation field,
|
||||||
|
// * default_score is a global parameter defined below.
|
||||||
|
//
|
||||||
|
// There should be exactly one sigmoid per number of supported output
|
||||||
|
// categories in the model, with either:
|
||||||
|
// * no fields set,
|
||||||
|
// * scale, slope and offset set,
|
||||||
|
// * all fields set.
|
||||||
|
message Sigmoid {
|
||||||
|
optional float scale = 1;
|
||||||
|
optional float slope = 2;
|
||||||
|
optional float offset = 3;
|
||||||
|
optional float min_score = 4;
|
||||||
|
}
|
||||||
|
repeated Sigmoid sigmoids = 1;
|
||||||
|
|
||||||
|
// Score transformation that defines the `g(x)` function in the above formula.
|
||||||
|
enum ScoreTransformation {
|
||||||
|
UNSPECIFIED = 0;
|
||||||
|
// g(x) = x.
|
||||||
|
IDENTITY = 1;
|
||||||
|
// g(x) = log(x).
|
||||||
|
LOG = 2;
|
||||||
|
// g(x) = log(x) - log(1 - x).
|
||||||
|
INVERSE_LOGISTIC = 3;
|
||||||
|
}
|
||||||
|
optional ScoreTransformation score_transformation = 2 [default = IDENTITY];
|
||||||
|
|
||||||
|
// Default score.
|
||||||
|
optional float default_score = 3;
|
||||||
|
}
|
|
@ -0,0 +1,309 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <memory>
|
||||||
|
#include <optional>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/strings/str_format.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/calculator_runner.h"
|
||||||
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::mediapipe::ParseTextProtoOrDie;
|
||||||
|
using ::testing::HasSubstr;
|
||||||
|
using ::testing::TestParamInfo;
|
||||||
|
using ::testing::TestWithParam;
|
||||||
|
using ::testing::Values;
|
||||||
|
using Node = ::mediapipe::CalculatorGraphConfig::Node;
|
||||||
|
|
||||||
|
// Builds the graph and feeds inputs.
|
||||||
|
void BuildGraph(CalculatorRunner* runner, std::vector<float> scores,
|
||||||
|
std::optional<std::vector<int>> indices = std::nullopt) {
|
||||||
|
auto scores_tensors = std::make_unique<std::vector<Tensor>>();
|
||||||
|
scores_tensors->emplace_back(
|
||||||
|
Tensor::ElementType::kFloat32,
|
||||||
|
Tensor::Shape{1, static_cast<int>(scores.size())});
|
||||||
|
auto scores_view = scores_tensors->back().GetCpuWriteView();
|
||||||
|
float* scores_buffer = scores_view.buffer<float>();
|
||||||
|
ASSERT_NE(scores_buffer, nullptr);
|
||||||
|
for (int i = 0; i < scores.size(); ++i) {
|
||||||
|
scores_buffer[i] = scores[i];
|
||||||
|
}
|
||||||
|
auto& input_scores_packets = runner->MutableInputs()->Tag("SCORES").packets;
|
||||||
|
input_scores_packets.push_back(
|
||||||
|
mediapipe::Adopt(scores_tensors.release()).At(mediapipe::Timestamp(0)));
|
||||||
|
|
||||||
|
if (indices.has_value()) {
|
||||||
|
auto indices_tensors = std::make_unique<std::vector<Tensor>>();
|
||||||
|
indices_tensors->emplace_back(
|
||||||
|
Tensor::ElementType::kFloat32,
|
||||||
|
Tensor::Shape{1, static_cast<int>(indices->size())});
|
||||||
|
auto indices_view = indices_tensors->back().GetCpuWriteView();
|
||||||
|
float* indices_buffer = indices_view.buffer<float>();
|
||||||
|
ASSERT_NE(indices_buffer, nullptr);
|
||||||
|
for (int i = 0; i < indices->size(); ++i) {
|
||||||
|
indices_buffer[i] = static_cast<float>((*indices)[i]);
|
||||||
|
}
|
||||||
|
auto& input_indices_packets =
|
||||||
|
runner->MutableInputs()->Tag("INDICES").packets;
|
||||||
|
input_indices_packets.push_back(mediapipe::Adopt(indices_tensors.release())
|
||||||
|
.At(mediapipe::Timestamp(0)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compares the provided tensor contents with the expected values.
|
||||||
|
void ValidateResult(const Tensor& actual, const std::vector<float>& expected) {
|
||||||
|
EXPECT_EQ(actual.element_type(), Tensor::ElementType::kFloat32);
|
||||||
|
EXPECT_EQ(expected.size(), actual.shape().num_elements());
|
||||||
|
auto view = actual.GetCpuReadView();
|
||||||
|
auto buffer = view.buffer<float>();
|
||||||
|
for (int i = 0; i < expected.size(); ++i) {
|
||||||
|
EXPECT_FLOAT_EQ(expected[i], buffer[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ScoreCalibrationCalculatorTest, FailsWithNoSigmoid) {
|
||||||
|
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||||
|
calculator: "ScoreCalibrationCalculator"
|
||||||
|
input_stream: "SCORES:scores"
|
||||||
|
output_stream: "CALIBRATED_SCORES:calibrated_scores"
|
||||||
|
options {
|
||||||
|
[mediapipe.tasks.ScoreCalibrationCalculatorOptions.ext] {}
|
||||||
|
}
|
||||||
|
)pb"));
|
||||||
|
|
||||||
|
BuildGraph(&runner, {0.5, 0.5, 0.5});
|
||||||
|
auto status = runner.Run();
|
||||||
|
|
||||||
|
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
|
||||||
|
EXPECT_THAT(status.message(),
|
||||||
|
HasSubstr("Expected at least one sigmoid, found none"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ScoreCalibrationCalculatorTest, FailsWithNegativeScale) {
|
||||||
|
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||||
|
calculator: "ScoreCalibrationCalculator"
|
||||||
|
input_stream: "SCORES:scores"
|
||||||
|
output_stream: "CALIBRATED_SCORES:calibrated_scores"
|
||||||
|
options {
|
||||||
|
[mediapipe.tasks.ScoreCalibrationCalculatorOptions.ext] {
|
||||||
|
sigmoids { slope: 1 offset: 1 scale: -1 }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb"));
|
||||||
|
|
||||||
|
BuildGraph(&runner, {0.5, 0.5, 0.5});
|
||||||
|
auto status = runner.Run();
|
||||||
|
|
||||||
|
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
|
||||||
|
EXPECT_THAT(
|
||||||
|
status.message(),
|
||||||
|
HasSubstr("The scale parameter of the sigmoids must be positive"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ScoreCalibrationCalculatorTest, FailsWithUnspecifiedTransformation) {
|
||||||
|
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||||
|
calculator: "ScoreCalibrationCalculator"
|
||||||
|
input_stream: "SCORES:scores"
|
||||||
|
output_stream: "CALIBRATED_SCORES:calibrated_scores"
|
||||||
|
options {
|
||||||
|
[mediapipe.tasks.ScoreCalibrationCalculatorOptions.ext] {
|
||||||
|
sigmoids { slope: 1 offset: 1 scale: 1 }
|
||||||
|
score_transformation: UNSPECIFIED
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb"));
|
||||||
|
|
||||||
|
BuildGraph(&runner, {0.5, 0.5, 0.5});
|
||||||
|
auto status = runner.Run();
|
||||||
|
|
||||||
|
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
|
||||||
|
EXPECT_THAT(status.message(),
|
||||||
|
HasSubstr("Unsupported ScoreTransformation type"));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Struct holding the parameters for parameterized tests below.
|
||||||
|
struct CalibrationTestParams {
|
||||||
|
// The score transformation to apply.
|
||||||
|
std::string score_transformation;
|
||||||
|
// The expected results.
|
||||||
|
std::vector<float> expected_results;
|
||||||
|
};
|
||||||
|
|
||||||
|
class CalibrationWithoutIndicesTest
|
||||||
|
: public TestWithParam<CalibrationTestParams> {};
|
||||||
|
|
||||||
|
TEST_P(CalibrationWithoutIndicesTest, Succeeds) {
|
||||||
|
CalculatorRunner runner(ParseTextProtoOrDie<Node>(absl::StrFormat(
|
||||||
|
R"pb(
|
||||||
|
calculator: "ScoreCalibrationCalculator"
|
||||||
|
input_stream: "SCORES:scores"
|
||||||
|
output_stream: "CALIBRATED_SCORES:calibrated_scores"
|
||||||
|
options {
|
||||||
|
[mediapipe.tasks.ScoreCalibrationCalculatorOptions.ext] {
|
||||||
|
sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 }
|
||||||
|
sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 min_score: 0.3 }
|
||||||
|
sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 min_score: 0.6 }
|
||||||
|
sigmoids {}
|
||||||
|
score_transformation: %s
|
||||||
|
default_score: 0.2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb",
|
||||||
|
GetParam().score_transformation)));
|
||||||
|
|
||||||
|
BuildGraph(&runner, {0.2, 0.3, 0.4, 0.5});
|
||||||
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
const Tensor& results = runner.Outputs()
|
||||||
|
.Get("CALIBRATED_SCORES", 0)
|
||||||
|
.packets[0]
|
||||||
|
.Get<std::vector<Tensor>>()[0];
|
||||||
|
|
||||||
|
ValidateResult(results, GetParam().expected_results);
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
ScoreCalibrationCalculatorTest, CalibrationWithoutIndicesTest,
|
||||||
|
Values(CalibrationTestParams{.score_transformation = "IDENTITY",
|
||||||
|
.expected_results = {0.4948505976,
|
||||||
|
0.5059588508, 0.2, 0.2}},
|
||||||
|
CalibrationTestParams{
|
||||||
|
.score_transformation = "LOG",
|
||||||
|
.expected_results = {0.2976901255, 0.3393665735, 0.2, 0.2}},
|
||||||
|
CalibrationTestParams{
|
||||||
|
.score_transformation = "INVERSE_LOGISTIC",
|
||||||
|
.expected_results = {0.3203217641, 0.3778080605, 0.2, 0.2}}),
|
||||||
|
[](const TestParamInfo<CalibrationWithoutIndicesTest::ParamType>& info) {
|
||||||
|
return info.param.score_transformation;
|
||||||
|
});
|
||||||
|
|
||||||
|
TEST(ScoreCalibrationCalculatorTest, FailsWithMissingSigmoids) {
|
||||||
|
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||||
|
calculator: "ScoreCalibrationCalculator"
|
||||||
|
input_stream: "SCORES:scores"
|
||||||
|
output_stream: "CALIBRATED_SCORES:calibrated_scores"
|
||||||
|
options {
|
||||||
|
[mediapipe.tasks.ScoreCalibrationCalculatorOptions.ext] {
|
||||||
|
sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 }
|
||||||
|
sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 min_score: 0.3 }
|
||||||
|
sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 min_score: 0.6 }
|
||||||
|
sigmoids {}
|
||||||
|
score_transformation: LOG
|
||||||
|
default_score: 0.2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb"));
|
||||||
|
|
||||||
|
BuildGraph(&runner, {0.2, 0.3, 0.4, 0.5, 0.6});
|
||||||
|
auto status = runner.Run();
|
||||||
|
|
||||||
|
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
|
||||||
|
EXPECT_THAT(status.message(),
|
||||||
|
HasSubstr("Mismatch between number of sigmoids"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ScoreCalibrationCalculatorTest, SucceedsWithIndices) {
|
||||||
|
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||||
|
calculator: "ScoreCalibrationCalculator"
|
||||||
|
input_stream: "SCORES:scores"
|
||||||
|
input_stream: "INDICES:indices"
|
||||||
|
output_stream: "CALIBRATED_SCORES:calibrated_scores"
|
||||||
|
options {
|
||||||
|
[mediapipe.tasks.ScoreCalibrationCalculatorOptions.ext] {
|
||||||
|
sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 }
|
||||||
|
sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 min_score: 0.3 }
|
||||||
|
sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 min_score: 0.6 }
|
||||||
|
sigmoids {}
|
||||||
|
score_transformation: IDENTITY
|
||||||
|
default_score: 0.2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb"));
|
||||||
|
std::vector<int> indices = {1, 2, 3, 0};
|
||||||
|
|
||||||
|
BuildGraph(&runner, {0.3, 0.4, 0.5, 0.2}, indices);
|
||||||
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
const Tensor& results = runner.Outputs()
|
||||||
|
.Get("CALIBRATED_SCORES", 0)
|
||||||
|
.packets[0]
|
||||||
|
.Get<std::vector<Tensor>>()[0];
|
||||||
|
ValidateResult(results, {0.5059588508, 0.2, 0.2, 0.4948505976});
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ScoreCalibrationCalculatorTest, FailsWithNegativeIndex) {
|
||||||
|
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||||
|
calculator: "ScoreCalibrationCalculator"
|
||||||
|
input_stream: "SCORES:scores"
|
||||||
|
input_stream: "INDICES:indices"
|
||||||
|
output_stream: "CALIBRATED_SCORES:calibrated_scores"
|
||||||
|
options {
|
||||||
|
[mediapipe.tasks.ScoreCalibrationCalculatorOptions.ext] {
|
||||||
|
sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 }
|
||||||
|
sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 min_score: 0.3 }
|
||||||
|
sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 min_score: 0.6 }
|
||||||
|
sigmoids {}
|
||||||
|
score_transformation: IDENTITY
|
||||||
|
default_score: 0.2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb"));
|
||||||
|
std::vector<int> indices = {0, 1, 2, -1};
|
||||||
|
|
||||||
|
BuildGraph(&runner, {0.2, 0.3, 0.4, 0.5}, indices);
|
||||||
|
auto status = runner.Run();
|
||||||
|
|
||||||
|
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
|
||||||
|
EXPECT_THAT(status.message(), HasSubstr("Expected positive indices"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ScoreCalibrationCalculatorTest, FailsWithOutOfBoundsIndex) {
|
||||||
|
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||||
|
calculator: "ScoreCalibrationCalculator"
|
||||||
|
input_stream: "SCORES:scores"
|
||||||
|
input_stream: "INDICES:indices"
|
||||||
|
output_stream: "CALIBRATED_SCORES:calibrated_scores"
|
||||||
|
options {
|
||||||
|
[mediapipe.tasks.ScoreCalibrationCalculatorOptions.ext] {
|
||||||
|
sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 }
|
||||||
|
sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 min_score: 0.3 }
|
||||||
|
sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 min_score: 0.6 }
|
||||||
|
sigmoids {}
|
||||||
|
score_transformation: IDENTITY
|
||||||
|
default_score: 0.2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb"));
|
||||||
|
std::vector<int> indices = {0, 1, 5, 3};
|
||||||
|
|
||||||
|
BuildGraph(&runner, {0.2, 0.3, 0.4, 0.5}, indices);
|
||||||
|
auto status = runner.Run();
|
||||||
|
|
||||||
|
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
|
||||||
|
EXPECT_THAT(
|
||||||
|
status.message(),
|
||||||
|
HasSubstr("Unable to get score calibration parameters for index"));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace mediapipe
|
|
@ -0,0 +1,115 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/strings/numbers.h"
|
||||||
|
#include "absl/strings/str_format.h"
|
||||||
|
#include "absl/strings/str_split.h"
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "mediapipe/framework/port/status_macros.h"
|
||||||
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
|
||||||
|
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace tasks {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Converts ScoreTransformation type from TFLite Metadata to calculator options.
|
||||||
|
ScoreCalibrationCalculatorOptions::ScoreTransformation
|
||||||
|
ConvertScoreTransformationType(tflite::ScoreTransformationType type) {
|
||||||
|
switch (type) {
|
||||||
|
case tflite::ScoreTransformationType_IDENTITY:
|
||||||
|
return ScoreCalibrationCalculatorOptions::IDENTITY;
|
||||||
|
case tflite::ScoreTransformationType_LOG:
|
||||||
|
return ScoreCalibrationCalculatorOptions::LOG;
|
||||||
|
case tflite::ScoreTransformationType_INVERSE_LOGISTIC:
|
||||||
|
return ScoreCalibrationCalculatorOptions::INVERSE_LOGISTIC;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parses a single line of the score calibration file into the provided sigmoid.
|
||||||
|
absl::Status FillSigmoidFromLine(
|
||||||
|
absl::string_view line,
|
||||||
|
ScoreCalibrationCalculatorOptions::Sigmoid* sigmoid) {
|
||||||
|
if (line.empty()) {
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
std::vector<absl::string_view> str_params = absl::StrSplit(line, ',');
|
||||||
|
if (str_params.size() != 3 && str_params.size() != 4) {
|
||||||
|
return CreateStatusWithPayload(
|
||||||
|
absl::StatusCode::kInvalidArgument,
|
||||||
|
absl::StrFormat("Expected 3 or 4 parameters per line in score "
|
||||||
|
"calibration file, got %d.",
|
||||||
|
str_params.size()),
|
||||||
|
MediaPipeTasksStatus::kMetadataMalformedScoreCalibrationError);
|
||||||
|
}
|
||||||
|
std::vector<float> params(str_params.size());
|
||||||
|
for (int i = 0; i < str_params.size(); ++i) {
|
||||||
|
if (!absl::SimpleAtof(str_params[i], ¶ms[i])) {
|
||||||
|
return CreateStatusWithPayload(
|
||||||
|
absl::StatusCode::kInvalidArgument,
|
||||||
|
absl::StrFormat(
|
||||||
|
"Could not parse score calibration parameter as float: %s.",
|
||||||
|
str_params[i]),
|
||||||
|
MediaPipeTasksStatus::kMetadataMalformedScoreCalibrationError);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (params[0] < 0) {
|
||||||
|
return CreateStatusWithPayload(
|
||||||
|
absl::StatusCode::kInvalidArgument,
|
||||||
|
absl::StrFormat(
|
||||||
|
"The scale parameter of the sigmoids must be positive, found %f.",
|
||||||
|
params[0]),
|
||||||
|
MediaPipeTasksStatus::kMetadataMalformedScoreCalibrationError);
|
||||||
|
}
|
||||||
|
sigmoid->set_scale(params[0]);
|
||||||
|
sigmoid->set_slope(params[1]);
|
||||||
|
sigmoid->set_offset(params[2]);
|
||||||
|
if (params.size() == 4) {
|
||||||
|
sigmoid->set_min_score(params[3]);
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
absl::Status ConfigureScoreCalibration(
|
||||||
|
tflite::ScoreTransformationType score_transformation, float default_score,
|
||||||
|
absl::string_view score_calibration_file,
|
||||||
|
ScoreCalibrationCalculatorOptions* calculator_options) {
|
||||||
|
calculator_options->set_score_transformation(
|
||||||
|
ConvertScoreTransformationType(score_transformation));
|
||||||
|
calculator_options->set_default_score(default_score);
|
||||||
|
|
||||||
|
if (score_calibration_file.empty()) {
|
||||||
|
return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
|
||||||
|
"Expected non-empty score calibration file.",
|
||||||
|
MediaPipeTasksStatus::kInvalidArgumentError);
|
||||||
|
}
|
||||||
|
std::vector<absl::string_view> lines =
|
||||||
|
absl::StrSplit(score_calibration_file, '\n');
|
||||||
|
for (const auto& line : lines) {
|
||||||
|
auto* sigmoid = calculator_options->add_sigmoids();
|
||||||
|
MP_RETURN_IF_ERROR(FillSigmoidFromLine(line, sigmoid));
|
||||||
|
}
|
||||||
|
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tasks
|
||||||
|
} // namespace mediapipe
|
|
@ -0,0 +1,38 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CALCULATORS_SCORE_CALIBRATION_UTILS_H_
|
||||||
|
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CALCULATORS_SCORE_CALIBRATION_UTILS_H_
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
|
||||||
|
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace tasks {
|
||||||
|
|
||||||
|
// Populates ScoreCalibrationCalculatorOptions given a TFLite Metadata score
|
||||||
|
// transformation type, default threshold and score calibration AssociatedFile
|
||||||
|
// contents, as specified in `TENSOR_AXIS_SCORE_CALIBRATION`:
|
||||||
|
// https://github.com/google/mediapipe/blob/master/mediapipe/tasks/metadata/metadata_schema.fbs
|
||||||
|
absl::Status ConfigureScoreCalibration(
|
||||||
|
tflite::ScoreTransformationType score_transformation, float default_score,
|
||||||
|
absl::string_view score_calibration_file,
|
||||||
|
ScoreCalibrationCalculatorOptions* options);
|
||||||
|
|
||||||
|
} // namespace tasks
|
||||||
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CALCULATORS_SCORE_CALIBRATION_UTILS_H_
|
|
@ -0,0 +1,130 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h"
|
||||||
|
|
||||||
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
|
||||||
|
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace tasks {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::testing::HasSubstr;
|
||||||
|
|
||||||
|
TEST(ConfigureScoreCalibrationTest, SucceedsWithoutTrailingNewline) {
|
||||||
|
ScoreCalibrationCalculatorOptions options;
|
||||||
|
std::string score_calibration_file =
|
||||||
|
absl::StrCat("\n", "0.1,0.2,0.3\n", "0.4,0.5,0.6,0.7");
|
||||||
|
|
||||||
|
MP_ASSERT_OK(ConfigureScoreCalibration(
|
||||||
|
tflite::ScoreTransformationType_IDENTITY,
|
||||||
|
/*default_score=*/0.5, score_calibration_file, &options));
|
||||||
|
|
||||||
|
EXPECT_THAT(
|
||||||
|
options,
|
||||||
|
EqualsProto(ParseTextProtoOrDie<ScoreCalibrationCalculatorOptions>(R"pb(
|
||||||
|
score_transformation: IDENTITY
|
||||||
|
default_score: 0.5
|
||||||
|
sigmoids {}
|
||||||
|
sigmoids { scale: 0.1 slope: 0.2 offset: 0.3 }
|
||||||
|
sigmoids { scale: 0.4 slope: 0.5 offset: 0.6 min_score: 0.7 }
|
||||||
|
)pb")));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ConfigureScoreCalibrationTest, SucceedsWithTrailingNewline) {
|
||||||
|
ScoreCalibrationCalculatorOptions options;
|
||||||
|
std::string score_calibration_file =
|
||||||
|
absl::StrCat("\n", "0.1,0.2,0.3\n", "0.4,0.5,0.6,0.7\n");
|
||||||
|
|
||||||
|
MP_ASSERT_OK(ConfigureScoreCalibration(tflite::ScoreTransformationType_LOG,
|
||||||
|
/*default_score=*/0.5,
|
||||||
|
score_calibration_file, &options));
|
||||||
|
|
||||||
|
EXPECT_THAT(
|
||||||
|
options,
|
||||||
|
EqualsProto(ParseTextProtoOrDie<ScoreCalibrationCalculatorOptions>(R"pb(
|
||||||
|
score_transformation: LOG
|
||||||
|
default_score: 0.5
|
||||||
|
sigmoids {}
|
||||||
|
sigmoids { scale: 0.1 slope: 0.2 offset: 0.3 }
|
||||||
|
sigmoids { scale: 0.4 slope: 0.5 offset: 0.6 min_score: 0.7 }
|
||||||
|
sigmoids {}
|
||||||
|
)pb")));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ConfigureScoreCalibrationTest, FailsWithEmptyFile) {
|
||||||
|
ScoreCalibrationCalculatorOptions options;
|
||||||
|
|
||||||
|
auto status =
|
||||||
|
ConfigureScoreCalibration(tflite::ScoreTransformationType_LOG,
|
||||||
|
/*default_score=*/0.5,
|
||||||
|
/*score_calibration_file=*/"", &options);
|
||||||
|
|
||||||
|
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
|
||||||
|
EXPECT_THAT(status.message(),
|
||||||
|
HasSubstr("Expected non-empty score calibration file"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ConfigureScoreCalibrationTest, FailsWithInvalidNumParameters) {
|
||||||
|
ScoreCalibrationCalculatorOptions options;
|
||||||
|
std::string score_calibration_file = absl::StrCat("0.1,0.2,0.3\n", "0.1,0.2");
|
||||||
|
|
||||||
|
auto status = ConfigureScoreCalibration(tflite::ScoreTransformationType_LOG,
|
||||||
|
/*default_score=*/0.5,
|
||||||
|
score_calibration_file, &options);
|
||||||
|
|
||||||
|
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
|
||||||
|
EXPECT_THAT(status.message(),
|
||||||
|
HasSubstr("Expected 3 or 4 parameters per line"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ConfigureScoreCalibrationTest, FailsWithNonParseableParameter) {
|
||||||
|
ScoreCalibrationCalculatorOptions options;
|
||||||
|
std::string score_calibration_file =
|
||||||
|
absl::StrCat("0.1,0.2,0.3\n", "0.1,foo,0.3\n");
|
||||||
|
|
||||||
|
auto status = ConfigureScoreCalibration(tflite::ScoreTransformationType_LOG,
|
||||||
|
/*default_score=*/0.5,
|
||||||
|
score_calibration_file, &options);
|
||||||
|
|
||||||
|
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
|
||||||
|
EXPECT_THAT(
|
||||||
|
status.message(),
|
||||||
|
HasSubstr("Could not parse score calibration parameter as float"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ConfigureScoreCalibrationTest, FailsWithNegativeScaleParameter) {
|
||||||
|
ScoreCalibrationCalculatorOptions options;
|
||||||
|
std::string score_calibration_file =
|
||||||
|
absl::StrCat("0.1,0.2,0.3\n", "-0.1,0.2,0.3\n");
|
||||||
|
|
||||||
|
auto status = ConfigureScoreCalibration(tflite::ScoreTransformationType_LOG,
|
||||||
|
/*default_score=*/0.5,
|
||||||
|
score_calibration_file, &options);
|
||||||
|
|
||||||
|
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
|
||||||
|
EXPECT_THAT(
|
||||||
|
status.message(),
|
||||||
|
HasSubstr("The scale parameter of the sigmoids must be positive"));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tasks
|
||||||
|
} // namespace mediapipe
|
|
@ -18,8 +18,18 @@ limitations under the License.
|
||||||
#include <errno.h>
|
#include <errno.h>
|
||||||
#include <fcntl.h>
|
#include <fcntl.h>
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
|
||||||
|
#ifdef ABSL_HAVE_MMAP
|
||||||
#include <sys/mman.h>
|
#include <sys/mman.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
#include <direct.h>
|
||||||
|
#include <io.h>
|
||||||
|
#include <windows.h>
|
||||||
|
#else
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
@ -44,12 +54,17 @@ using ::absl::StatusCode;
|
||||||
// file descriptor correctly, as according to mmap(2), the offset used in mmap
|
// file descriptor correctly, as according to mmap(2), the offset used in mmap
|
||||||
// must be a multiple of sysconf(_SC_PAGE_SIZE).
|
// must be a multiple of sysconf(_SC_PAGE_SIZE).
|
||||||
int64 GetPageSizeAlignedOffset(int64 offset) {
|
int64 GetPageSizeAlignedOffset(int64 offset) {
|
||||||
|
#ifdef _WIN32
|
||||||
|
// mmap is not used on Windows
|
||||||
|
return -1;
|
||||||
|
#else
|
||||||
int64 aligned_offset = offset;
|
int64 aligned_offset = offset;
|
||||||
int64 page_size = sysconf(_SC_PAGE_SIZE);
|
int64 page_size = sysconf(_SC_PAGE_SIZE);
|
||||||
if (offset % page_size != 0) {
|
if (offset % page_size != 0) {
|
||||||
aligned_offset = offset / page_size * page_size;
|
aligned_offset = offset / page_size * page_size;
|
||||||
}
|
}
|
||||||
return aligned_offset;
|
return aligned_offset;
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -69,6 +84,12 @@ ExternalFileHandler::CreateFromExternalFile(
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status ExternalFileHandler::MapExternalFile() {
|
absl::Status ExternalFileHandler::MapExternalFile() {
|
||||||
|
// TODO: Add Windows support
|
||||||
|
#ifdef _WIN32
|
||||||
|
return CreateStatusWithPayload(StatusCode::kFailedPrecondition,
|
||||||
|
"File loading is not yet supported on Windows",
|
||||||
|
MediaPipeTasksStatus::kFileReadError);
|
||||||
|
#else
|
||||||
if (!external_file_.file_content().empty()) {
|
if (!external_file_.file_content().empty()) {
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
@ -169,6 +190,7 @@ absl::Status ExternalFileHandler::MapExternalFile() {
|
||||||
MediaPipeTasksStatus::kFileMmapError);
|
MediaPipeTasksStatus::kFileMmapError);
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::string_view ExternalFileHandler::GetFileContent() {
|
absl::string_view ExternalFileHandler::GetFileContent() {
|
||||||
|
@ -182,9 +204,11 @@ absl::string_view ExternalFileHandler::GetFileContent() {
|
||||||
}
|
}
|
||||||
|
|
||||||
ExternalFileHandler::~ExternalFileHandler() {
|
ExternalFileHandler::~ExternalFileHandler() {
|
||||||
|
#ifndef _WIN32
|
||||||
if (buffer_ != MAP_FAILED) {
|
if (buffer_ != MAP_FAILED) {
|
||||||
munmap(buffer_, buffer_aligned_size_);
|
munmap(buffer_, buffer_aligned_size_);
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
if (owned_fd_ >= 0) {
|
if (owned_fd_ >= 0) {
|
||||||
close(owned_fd_);
|
close(owned_fd_);
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
// TODO Refactor naming and class structure of hand related Tasks.
|
||||||
syntax = "proto2";
|
syntax = "proto2";
|
||||||
|
|
||||||
package mediapipe.tasks.vision.hand_gesture_recognizer.proto;
|
package mediapipe.tasks.vision.hand_gesture_recognizer.proto;
|
||||||
|
|
|
@ -24,7 +24,7 @@ message HandLandmarkDetectorOptions {
|
||||||
extend mediapipe.CalculatorOptions {
|
extend mediapipe.CalculatorOptions {
|
||||||
optional HandLandmarkDetectorOptions ext = 462713202;
|
optional HandLandmarkDetectorOptions ext = 462713202;
|
||||||
}
|
}
|
||||||
// Base options for configuring Task library, such as specifying the TfLite
|
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
|
||||||
// model file with metadata, accelerator options, etc.
|
// model file with metadata, accelerator options, etc.
|
||||||
optional core.proto.BaseOptions base_options = 1;
|
optional core.proto.BaseOptions base_options = 1;
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,7 @@ message ImageClassifierOptions {
|
||||||
extend mediapipe.CalculatorOptions {
|
extend mediapipe.CalculatorOptions {
|
||||||
optional ImageClassifierOptions ext = 456383383;
|
optional ImageClassifierOptions ext = 456383383;
|
||||||
}
|
}
|
||||||
// Base options for configuring Task library, such as specifying the TfLite
|
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
|
||||||
// model file with metadata, accelerator options, etc.
|
// model file with metadata, accelerator options, etc.
|
||||||
optional core.proto.BaseOptions base_options = 1;
|
optional core.proto.BaseOptions base_options = 1;
|
||||||
|
|
||||||
|
|
|
@ -12,34 +12,25 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
|
|
||||||
|
|
||||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
licenses(["notice"])
|
licenses(["notice"])
|
||||||
|
|
||||||
mediapipe_proto_library(
|
|
||||||
name = "image_segmenter_options_proto",
|
|
||||||
srcs = ["image_segmenter_options.proto"],
|
|
||||||
deps = [
|
|
||||||
"//mediapipe/framework:calculator_options_proto",
|
|
||||||
"//mediapipe/framework:calculator_proto",
|
|
||||||
"//mediapipe/tasks/cc/components:segmenter_options_proto",
|
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "image_segmenter",
|
name = "image_segmenter",
|
||||||
srcs = ["image_segmenter.cc"],
|
srcs = ["image_segmenter.cc"],
|
||||||
hdrs = ["image_segmenter.h"],
|
hdrs = ["image_segmenter.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":image_segmenter_graph",
|
":image_segmenter_graph",
|
||||||
":image_segmenter_options_cc_proto",
|
|
||||||
"//mediapipe/framework/api2:builder",
|
"//mediapipe/framework/api2:builder",
|
||||||
"//mediapipe/framework/formats:image",
|
"//mediapipe/framework/formats:image",
|
||||||
"//mediapipe/tasks/cc/core:base_task_api",
|
"//mediapipe/tasks/cc/components:segmenter_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/core:task_api_factory",
|
"//mediapipe/tasks/cc/core:base_options",
|
||||||
|
"//mediapipe/tasks/cc/core:utils",
|
||||||
|
"//mediapipe/tasks/cc/vision/core:base_vision_task_api",
|
||||||
|
"//mediapipe/tasks/cc/vision/core:running_mode",
|
||||||
|
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
|
||||||
|
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
|
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
|
||||||
|
@ -51,7 +42,6 @@ cc_library(
|
||||||
name = "image_segmenter_graph",
|
name = "image_segmenter_graph",
|
||||||
srcs = ["image_segmenter_graph.cc"],
|
srcs = ["image_segmenter_graph.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
":image_segmenter_options_cc_proto",
|
|
||||||
"//mediapipe/calculators/core:merge_to_vector_calculator",
|
"//mediapipe/calculators/core:merge_to_vector_calculator",
|
||||||
"//mediapipe/calculators/image:image_properties_calculator",
|
"//mediapipe/calculators/image:image_properties_calculator",
|
||||||
"//mediapipe/calculators/tensor:image_to_tensor_calculator",
|
"//mediapipe/calculators/tensor:image_to_tensor_calculator",
|
||||||
|
@ -70,6 +60,7 @@ cc_library(
|
||||||
"//mediapipe/tasks/cc/core:model_task_graph",
|
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||||
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
|
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
|
||||||
"//mediapipe/tasks/cc/metadata:metadata_extractor",
|
"//mediapipe/tasks/cc/metadata:metadata_extractor",
|
||||||
|
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto",
|
||||||
"//mediapipe/tasks/metadata:metadata_schema_cc",
|
"//mediapipe/tasks/metadata:metadata_schema_cc",
|
||||||
"//mediapipe/util:label_map_cc_proto",
|
"//mediapipe/util:label_map_cc_proto",
|
||||||
"//mediapipe/util:label_map_util",
|
"//mediapipe/util:label_map_util",
|
||||||
|
@ -82,9 +73,9 @@ cc_library(
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "custom_op_resolvers",
|
name = "image_segmenter_op_resolvers",
|
||||||
srcs = ["custom_op_resolvers.cc"],
|
srcs = ["image_segmenter_op_resolvers.cc"],
|
||||||
hdrs = ["custom_op_resolvers.h"],
|
hdrs = ["image_segmenter_op_resolvers.h"],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/util/tflite/operations:landmarks_to_transform_matrix",
|
"//mediapipe/util/tflite/operations:landmarks_to_transform_matrix",
|
||||||
"//mediapipe/util/tflite/operations:max_pool_argmax",
|
"//mediapipe/util/tflite/operations:max_pool_argmax",
|
134
mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc
Normal file
134
mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc
Normal file
|
@ -0,0 +1,134 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h"
|
||||||
|
|
||||||
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
|
#include "mediapipe/framework/formats/image.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/segmenter_options.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/utils.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace tasks {
|
||||||
|
namespace vision {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kSegmentationStreamName[] = "segmented_mask_out";
|
||||||
|
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
|
||||||
|
constexpr char kImageInStreamName[] = "image_in";
|
||||||
|
constexpr char kImageOutStreamName[] = "image_out";
|
||||||
|
constexpr char kImageTag[] = "IMAGE";
|
||||||
|
constexpr char kSubgraphTypeName[] =
|
||||||
|
"mediapipe.tasks.vision.ImageSegmenterGraph";
|
||||||
|
|
||||||
|
using ::mediapipe::CalculatorGraphConfig;
|
||||||
|
using ::mediapipe::Image;
|
||||||
|
using ImageSegmenterOptionsProto =
|
||||||
|
image_segmenter::proto::ImageSegmenterOptions;
|
||||||
|
|
||||||
|
// Creates a MediaPipe graph config that only contains a single subgraph node of
|
||||||
|
// "mediapipe.tasks.vision.ImageSegmenterGraph".
|
||||||
|
CalculatorGraphConfig CreateGraphConfig(
|
||||||
|
std::unique_ptr<ImageSegmenterOptionsProto> options,
|
||||||
|
bool enable_flow_limiting) {
|
||||||
|
api2::builder::Graph graph;
|
||||||
|
auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
|
||||||
|
task_subgraph.GetOptions<ImageSegmenterOptionsProto>().Swap(options.get());
|
||||||
|
graph.In(kImageTag).SetName(kImageInStreamName);
|
||||||
|
task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >>
|
||||||
|
graph.Out(kGroupedSegmentationTag);
|
||||||
|
task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
|
||||||
|
graph.Out(kImageTag);
|
||||||
|
if (enable_flow_limiting) {
|
||||||
|
return tasks::core::AddFlowLimiterCalculator(
|
||||||
|
graph, task_subgraph, {kImageTag}, kGroupedSegmentationTag);
|
||||||
|
}
|
||||||
|
graph.In(kImageTag) >> task_subgraph.In(kImageTag);
|
||||||
|
return graph.GetConfig();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Converts the user-facing ImageSegmenterOptions struct to the internal
|
||||||
|
// ImageSegmenterOptions proto.
|
||||||
|
std::unique_ptr<ImageSegmenterOptionsProto> ConvertImageSegmenterOptionsToProto(
|
||||||
|
ImageSegmenterOptions* options) {
|
||||||
|
auto options_proto = std::make_unique<ImageSegmenterOptionsProto>();
|
||||||
|
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
|
||||||
|
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
|
||||||
|
options_proto->mutable_base_options()->Swap(base_options_proto.get());
|
||||||
|
options_proto->mutable_base_options()->set_use_stream_mode(
|
||||||
|
options->running_mode != core::RunningMode::IMAGE);
|
||||||
|
options_proto->set_display_names_locale(options->display_names_locale);
|
||||||
|
switch (options->output_type) {
|
||||||
|
case ImageSegmenterOptions::OutputType::CATEGORY_MASK:
|
||||||
|
options_proto->mutable_segmenter_options()->set_output_type(
|
||||||
|
SegmenterOptions::CATEGORY_MASK);
|
||||||
|
break;
|
||||||
|
case ImageSegmenterOptions::OutputType::CONFIDENCE_MASK:
|
||||||
|
options_proto->mutable_segmenter_options()->set_output_type(
|
||||||
|
SegmenterOptions::CONFIDENCE_MASK);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
switch (options->activation) {
|
||||||
|
case ImageSegmenterOptions::Activation::NONE:
|
||||||
|
options_proto->mutable_segmenter_options()->set_activation(
|
||||||
|
SegmenterOptions::NONE);
|
||||||
|
break;
|
||||||
|
case ImageSegmenterOptions::Activation::SIGMOID:
|
||||||
|
options_proto->mutable_segmenter_options()->set_activation(
|
||||||
|
SegmenterOptions::SIGMOID);
|
||||||
|
break;
|
||||||
|
case ImageSegmenterOptions::Activation::SOFTMAX:
|
||||||
|
options_proto->mutable_segmenter_options()->set_activation(
|
||||||
|
SegmenterOptions::SOFTMAX);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
return options_proto;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
|
||||||
|
std::unique_ptr<ImageSegmenterOptions> options) {
|
||||||
|
auto options_proto = ConvertImageSegmenterOptionsToProto(options.get());
|
||||||
|
tasks::core::PacketsCallback packets_callback = nullptr;
|
||||||
|
return core::VisionTaskApiFactory::Create<ImageSegmenter,
|
||||||
|
ImageSegmenterOptionsProto>(
|
||||||
|
CreateGraphConfig(
|
||||||
|
std::move(options_proto),
|
||||||
|
options->running_mode == core::RunningMode::LIVE_STREAM),
|
||||||
|
std::move(options->base_options.op_resolver), options->running_mode,
|
||||||
|
std::move(packets_callback));
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::StatusOr<std::vector<Image>> ImageSegmenter::Segment(
|
||||||
|
mediapipe::Image image) {
|
||||||
|
if (image.UsesGpu()) {
|
||||||
|
return CreateStatusWithPayload(
|
||||||
|
absl::StatusCode::kInvalidArgument,
|
||||||
|
absl::StrCat("GPU input images are currently not supported."),
|
||||||
|
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
|
||||||
|
}
|
||||||
|
ASSIGN_OR_RETURN(
|
||||||
|
auto output_packets,
|
||||||
|
ProcessImageData({{kImageInStreamName,
|
||||||
|
mediapipe::MakePacket<Image>(std::move(image))}}));
|
||||||
|
return output_packets[kSegmentationStreamName].Get<std::vector<Image>>();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace vision
|
||||||
|
} // namespace tasks
|
||||||
|
} // namespace mediapipe
|
123
mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h
Normal file
123
mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h
Normal file
|
@ -0,0 +1,123 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_H_
|
||||||
|
#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/memory/memory.h"
|
||||||
|
#include "absl/status/statusor.h"
|
||||||
|
#include "mediapipe/framework/formats/image.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
|
||||||
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
|
#include "tensorflow/lite/kernels/register.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace tasks {
|
||||||
|
namespace vision {
|
||||||
|
|
||||||
|
// The options for configuring a mediapipe image segmenter task.
|
||||||
|
struct ImageSegmenterOptions {
|
||||||
|
// Base options for configuring MediaPipe Tasks, such as specifying the model
|
||||||
|
// file with metadata, accelerator options, op resolver, etc.
|
||||||
|
tasks::core::BaseOptions base_options;
|
||||||
|
|
||||||
|
// The running mode of the task. Default to the image mode.
|
||||||
|
// Image segmenter has three running modes:
|
||||||
|
// 1) The image mode for segmenting image on single image inputs.
|
||||||
|
// 2) The video mode for segmenting image on the decoded frames of a video.
|
||||||
|
// 3) The live stream mode for segmenting image on the live stream of input
|
||||||
|
// data, such as from camera. In this mode, the "result_callback" below must
|
||||||
|
// be specified to receive the segmentation results asynchronously.
|
||||||
|
core::RunningMode running_mode = core::RunningMode::IMAGE;
|
||||||
|
|
||||||
|
// The locale to use for display names specified through the TFLite Model
|
||||||
|
// Metadata, if any. Defaults to English.
|
||||||
|
std::string display_names_locale = "en";
|
||||||
|
|
||||||
|
// The output type of segmentation results.
|
||||||
|
enum OutputType {
|
||||||
|
// Gives a single output mask where each pixel represents the class which
|
||||||
|
// the pixel in the original image was predicted to belong to.
|
||||||
|
CATEGORY_MASK = 0,
|
||||||
|
// Gives a list of output masks where, for each mask, each pixel represents
|
||||||
|
// the prediction confidence, usually in the [0, 1] range.
|
||||||
|
CONFIDENCE_MASK = 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
OutputType output_type = OutputType::CATEGORY_MASK;
|
||||||
|
|
||||||
|
// The activation function used on the raw segmentation model output.
|
||||||
|
enum Activation {
|
||||||
|
NONE = 0, // No activation function is used.
|
||||||
|
SIGMOID = 1, // Assumes 1-channel input tensor.
|
||||||
|
SOFTMAX = 2, // Assumes multi-channel input tensor.
|
||||||
|
};
|
||||||
|
|
||||||
|
Activation activation = Activation::NONE;
|
||||||
|
|
||||||
|
// The user-defined result callback for processing live stream data.
|
||||||
|
// The result callback should only be specified when the running mode is set
|
||||||
|
// to RunningMode::LIVE_STREAM.
|
||||||
|
std::function<void(absl::StatusOr<std::vector<mediapipe::Image>>,
|
||||||
|
const Image&, int64)>
|
||||||
|
result_callback = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Performs segmentation on images.
|
||||||
|
//
|
||||||
|
// The API expects a TFLite model with mandatory TFLite Model Metadata.
|
||||||
|
//
|
||||||
|
// Input tensor:
|
||||||
|
// (kTfLiteUInt8/kTfLiteFloat32)
|
||||||
|
// - image input of size `[batch x height x width x channels]`.
|
||||||
|
// - batch inference is not supported (`batch` is required to be 1).
|
||||||
|
// - RGB and greyscale inputs are supported (`channels` is required to be
|
||||||
|
// 1 or 3).
|
||||||
|
// - if type is kTfLiteFloat32, NormalizationOptions are required to be
|
||||||
|
// attached to the metadata for input normalization.
|
||||||
|
// Output tensors:
|
||||||
|
// (kTfLiteUInt8/kTfLiteFloat32)
|
||||||
|
// - list of segmented masks.
|
||||||
|
// - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1.
|
||||||
|
// - if `output_type` is CONFIDENCE_MASK, float32 Image list of size
|
||||||
|
// `cahnnels`.
|
||||||
|
// - batch is always 1
|
||||||
|
// An example of such model can be found at:
|
||||||
|
// https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2
|
||||||
|
class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
|
||||||
|
public:
|
||||||
|
using BaseVisionTaskApi::BaseVisionTaskApi;
|
||||||
|
|
||||||
|
// Creates an ImageSegmenter from the provided options. A non-default
|
||||||
|
// OpResolver can be specified in the BaseOptions of ImageSegmenterOptions,
|
||||||
|
// to support custom Ops of the segmentation model.
|
||||||
|
static absl::StatusOr<std::unique_ptr<ImageSegmenter>> Create(
|
||||||
|
std::unique_ptr<ImageSegmenterOptions> options);
|
||||||
|
|
||||||
|
// Runs the actual segmentation task.
|
||||||
|
absl::StatusOr<std::vector<mediapipe::Image>> Segment(mediapipe::Image image);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace vision
|
||||||
|
} // namespace tasks
|
||||||
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_H_
|
|
@ -33,7 +33,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
|
||||||
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||||
#include "mediapipe/tasks/cc/vision/segmentation/image_segmenter_options.pb.h"
|
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
|
||||||
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
|
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
|
||||||
#include "mediapipe/util/label_map.pb.h"
|
#include "mediapipe/util/label_map.pb.h"
|
||||||
#include "mediapipe/util/label_map_util.h"
|
#include "mediapipe/util/label_map_util.h"
|
||||||
|
@ -53,6 +53,7 @@ using ::mediapipe::api2::builder::MultiSource;
|
||||||
using ::mediapipe::api2::builder::Source;
|
using ::mediapipe::api2::builder::Source;
|
||||||
using ::mediapipe::tasks::SegmenterOptions;
|
using ::mediapipe::tasks::SegmenterOptions;
|
||||||
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
|
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
|
||||||
|
using ::mediapipe::tasks::vision::image_segmenter::proto::ImageSegmenterOptions;
|
||||||
using ::tflite::Tensor;
|
using ::tflite::Tensor;
|
||||||
using ::tflite::TensorMetadata;
|
using ::tflite::TensorMetadata;
|
||||||
using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
|
using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
|
||||||
|
@ -63,6 +64,14 @@ constexpr char kImageTag[] = "IMAGE";
|
||||||
constexpr char kTensorsTag[] = "TENSORS";
|
constexpr char kTensorsTag[] = "TENSORS";
|
||||||
constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
|
constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
|
||||||
|
|
||||||
|
// Struct holding the different output streams produced by the image segmenter
|
||||||
|
// subgraph.
|
||||||
|
struct ImageSegmenterOutputs {
|
||||||
|
std::vector<Source<Image>> segmented_masks;
|
||||||
|
// The same as the input image, mainly used for live stream mode.
|
||||||
|
Source<Image> image;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
absl::Status SanityCheckOptions(const ImageSegmenterOptions& options) {
|
absl::Status SanityCheckOptions(const ImageSegmenterOptions& options) {
|
||||||
|
@ -140,6 +149,10 @@ absl::StatusOr<const Tensor*> GetOutputTensor(
|
||||||
|
|
||||||
// An "mediapipe.tasks.vision.ImageSegmenterGraph" performs semantic
|
// An "mediapipe.tasks.vision.ImageSegmenterGraph" performs semantic
|
||||||
// segmentation.
|
// segmentation.
|
||||||
|
// Two kinds of outputs are provided: SEGMENTATION and GROUPED_SEGMENTATION.
|
||||||
|
// Users can retrieve segmented mask of only particular category/channel from
|
||||||
|
// SEGMENTATION, and users can also get all segmented masks from
|
||||||
|
// GROUPED_SEGMENTATION.
|
||||||
// - Accepts CPU input images and outputs segmented masks on CPU.
|
// - Accepts CPU input images and outputs segmented masks on CPU.
|
||||||
//
|
//
|
||||||
// Inputs:
|
// Inputs:
|
||||||
|
@ -147,8 +160,13 @@ absl::StatusOr<const Tensor*> GetOutputTensor(
|
||||||
// Image to perform segmentation on.
|
// Image to perform segmentation on.
|
||||||
//
|
//
|
||||||
// Outputs:
|
// Outputs:
|
||||||
// SEGMENTATION - SEGMENTATION
|
// SEGMENTATION - mediapipe::Image @Multiple
|
||||||
// Segmented masks.
|
// Segmented masks for individual category. Segmented mask of single
|
||||||
|
// category can be accessed by index based output stream.
|
||||||
|
// GROUPED_SEGMENTATION - std::vector<mediapipe::Image>
|
||||||
|
// The output segmented masks grouped in a vector.
|
||||||
|
// IMAGE - mediapipe::Image
|
||||||
|
// The image that image segmenter runs on.
|
||||||
//
|
//
|
||||||
// Example:
|
// Example:
|
||||||
// node {
|
// node {
|
||||||
|
@ -156,7 +174,8 @@ absl::StatusOr<const Tensor*> GetOutputTensor(
|
||||||
// input_stream: "IMAGE:image"
|
// input_stream: "IMAGE:image"
|
||||||
// output_stream: "SEGMENTATION:segmented_masks"
|
// output_stream: "SEGMENTATION:segmented_masks"
|
||||||
// options {
|
// options {
|
||||||
// [mediapipe.tasks.ImageSegmenterOptions.ext] {
|
// [mediapipe.tasks.vision.image_segmenter.proto.ImageSegmenterOptions.ext]
|
||||||
|
// {
|
||||||
// segmenter_options {
|
// segmenter_options {
|
||||||
// output_type: CONFIDENCE_MASK
|
// output_type: CONFIDENCE_MASK
|
||||||
// activation: SOFTMAX
|
// activation: SOFTMAX
|
||||||
|
@ -171,20 +190,22 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
|
||||||
ASSIGN_OR_RETURN(const auto* model_resources,
|
ASSIGN_OR_RETURN(const auto* model_resources,
|
||||||
CreateModelResources<ImageSegmenterOptions>(sc));
|
CreateModelResources<ImageSegmenterOptions>(sc));
|
||||||
Graph graph;
|
Graph graph;
|
||||||
ASSIGN_OR_RETURN(auto segmentations,
|
ASSIGN_OR_RETURN(auto output_streams,
|
||||||
BuildSegmentationTask(
|
BuildSegmentationTask(
|
||||||
sc->Options<ImageSegmenterOptions>(), *model_resources,
|
sc->Options<ImageSegmenterOptions>(), *model_resources,
|
||||||
graph[Input<Image>(kImageTag)], graph));
|
graph[Input<Image>(kImageTag)], graph));
|
||||||
|
|
||||||
auto& merge_images_to_vector =
|
auto& merge_images_to_vector =
|
||||||
graph.AddNode("MergeImagesToVectorCalculator");
|
graph.AddNode("MergeImagesToVectorCalculator");
|
||||||
for (int i = 0; i < segmentations.size(); ++i) {
|
for (int i = 0; i < output_streams.segmented_masks.size(); ++i) {
|
||||||
segmentations[i] >> merge_images_to_vector[Input<Image>::Multiple("")][i];
|
output_streams.segmented_masks[i] >>
|
||||||
segmentations[i] >> graph[Output<Image>::Multiple(kSegmentationTag)][i];
|
merge_images_to_vector[Input<Image>::Multiple("")][i];
|
||||||
|
output_streams.segmented_masks[i] >>
|
||||||
|
graph[Output<Image>::Multiple(kSegmentationTag)][i];
|
||||||
}
|
}
|
||||||
merge_images_to_vector.Out("") >>
|
merge_images_to_vector.Out("") >>
|
||||||
graph[Output<std::vector<Image>>(kGroupedSegmentationTag)];
|
graph[Output<std::vector<Image>>(kGroupedSegmentationTag)];
|
||||||
|
output_streams.image >> graph[Output<Image>(kImageTag)];
|
||||||
return graph.GetConfig();
|
return graph.GetConfig();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -193,12 +214,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
|
||||||
// builder::Graph instance. The segmentation pipeline takes images
|
// builder::Graph instance. The segmentation pipeline takes images
|
||||||
// (mediapipe::Image) as the input and returns segmented image mask as output.
|
// (mediapipe::Image) as the input and returns segmented image mask as output.
|
||||||
//
|
//
|
||||||
// task_options: the mediapipe tasks ImageSegmenterOptions.
|
// task_options: the mediapipe tasks ImageSegmenterOptions proto.
|
||||||
// model_resources: the ModelSources object initialized from a segmentation
|
// model_resources: the ModelSources object initialized from a segmentation
|
||||||
// model file with model metadata.
|
// model file with model metadata.
|
||||||
// image_in: (mediapipe::Image) stream to run segmentation on.
|
// image_in: (mediapipe::Image) stream to run segmentation on.
|
||||||
// graph: the mediapipe builder::Graph instance to be updated.
|
// graph: the mediapipe builder::Graph instance to be updated.
|
||||||
absl::StatusOr<std::vector<Source<Image>>> BuildSegmentationTask(
|
absl::StatusOr<ImageSegmenterOutputs> BuildSegmentationTask(
|
||||||
const ImageSegmenterOptions& task_options,
|
const ImageSegmenterOptions& task_options,
|
||||||
const core::ModelResources& model_resources, Source<Image> image_in,
|
const core::ModelResources& model_resources, Source<Image> image_in,
|
||||||
Graph& graph) {
|
Graph& graph) {
|
||||||
|
@ -246,7 +267,10 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
|
||||||
tensor_to_images[Output<Image>::Multiple(kSegmentationTag)][i]));
|
tensor_to_images[Output<Image>::Multiple(kSegmentationTag)][i]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return segmented_masks;
|
return {{
|
||||||
|
.segmented_masks = segmented_masks,
|
||||||
|
.image = preprocessing[Output<Image>(kImageTag)],
|
||||||
|
}};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "mediapipe/tasks/cc/vision/segmentation/custom_op_resolvers.h"
|
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h"
|
||||||
|
|
||||||
#include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h"
|
#include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h"
|
||||||
#include "mediapipe/util/tflite/operations/max_pool_argmax.h"
|
#include "mediapipe/util/tflite/operations/max_pool_argmax.h"
|
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef MEDIAPIPE_TASKS_CC_VISION_SEGMENTATION_CUSTOM_OP_RESOLVERS_H_
|
#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_
|
||||||
#define MEDIAPIPE_TASKS_CC_VISION_SEGMENTATION_CUSTOM_OP_RESOLVERS_H_
|
#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_
|
||||||
|
|
||||||
#include "tensorflow/lite/kernels/register.h"
|
#include "tensorflow/lite/kernels/register.h"
|
||||||
|
|
||||||
|
@ -34,4 +34,4 @@ class SelfieSegmentationModelOpResolver
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
||||||
#endif // MEDIAPIPE_TASKS_CC_VISION_SEGMENTATION_CUSTOM_OP_RESOLVERS_H_
|
#endif // MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "mediapipe/tasks/cc/vision/segmentation/image_segmenter.h"
|
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h"
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
@ -32,8 +32,8 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/components/segmenter_options.pb.h"
|
#include "mediapipe/tasks/cc/components/segmenter_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/segmentation/custom_op_resolvers.h"
|
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h"
|
||||||
#include "mediapipe/tasks/cc/vision/segmentation/image_segmenter_options.pb.h"
|
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||||
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
||||||
|
@ -46,11 +46,8 @@ namespace {
|
||||||
|
|
||||||
using ::mediapipe::Image;
|
using ::mediapipe::Image;
|
||||||
using ::mediapipe::file::JoinPath;
|
using ::mediapipe::file::JoinPath;
|
||||||
using ::mediapipe::tasks::ImageSegmenterOptions;
|
|
||||||
using ::mediapipe::tasks::SegmenterOptions;
|
|
||||||
using ::testing::HasSubstr;
|
using ::testing::HasSubstr;
|
||||||
using ::testing::Optional;
|
using ::testing::Optional;
|
||||||
using ::tflite::ops::builtin::BuiltinOpResolver;
|
|
||||||
|
|
||||||
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
|
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
|
||||||
constexpr char kDeeplabV3WithMetadata[] = "deeplabv3.tflite";
|
constexpr char kDeeplabV3WithMetadata[] = "deeplabv3.tflite";
|
||||||
|
@ -167,19 +164,19 @@ class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver {
|
||||||
|
|
||||||
TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) {
|
TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) {
|
||||||
auto options = std::make_unique<ImageSegmenterOptions>();
|
auto options = std::make_unique<ImageSegmenterOptions>();
|
||||||
options->mutable_base_options()->mutable_model_file()->set_file_name(
|
options->base_options.model_file_name =
|
||||||
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata));
|
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
|
||||||
MP_ASSERT_OK(ImageSegmenter::Create(std::move(options),
|
options->base_options.op_resolver = absl::make_unique<DeepLabOpResolver>();
|
||||||
absl::make_unique<DeepLabOpResolver>()));
|
MP_ASSERT_OK(ImageSegmenter::Create(std::move(options)));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
|
TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
|
||||||
auto options = std::make_unique<ImageSegmenterOptions>();
|
auto options = std::make_unique<ImageSegmenterOptions>();
|
||||||
options->mutable_base_options()->mutable_model_file()->set_file_name(
|
options->base_options.model_file_name =
|
||||||
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata));
|
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
|
||||||
|
options->base_options.op_resolver =
|
||||||
auto segmenter_or = ImageSegmenter::Create(
|
absl::make_unique<DeepLabOpResolverMissingOps>();
|
||||||
std::move(options), absl::make_unique<DeepLabOpResolverMissingOps>());
|
auto segmenter_or = ImageSegmenter::Create(std::move(options));
|
||||||
// TODO: Make MediaPipe InferenceCalculator report the detailed
|
// TODO: Make MediaPipe InferenceCalculator report the detailed
|
||||||
// interpreter errors (e.g., "Encountered unresolved custom op").
|
// interpreter errors (e.g., "Encountered unresolved custom op").
|
||||||
EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInternal);
|
EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInternal);
|
||||||
|
@ -202,24 +199,6 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
|
||||||
MediaPipeTasksStatus::kRunnerInitializationError))));
|
MediaPipeTasksStatus::kRunnerInitializationError))));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(CreateFromOptionsTest, FailsWithUnspecifiedOutputType) {
|
|
||||||
auto options = std::make_unique<ImageSegmenterOptions>();
|
|
||||||
options->mutable_base_options()->mutable_model_file()->set_file_name(
|
|
||||||
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata));
|
|
||||||
options->mutable_segmenter_options()->set_output_type(
|
|
||||||
SegmenterOptions::UNSPECIFIED);
|
|
||||||
|
|
||||||
auto segmenter_or = ImageSegmenter::Create(
|
|
||||||
std::move(options), absl::make_unique<DeepLabOpResolver>());
|
|
||||||
|
|
||||||
EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInvalidArgument);
|
|
||||||
EXPECT_THAT(segmenter_or.status().message(),
|
|
||||||
HasSubstr("`output_type` must not be UNSPECIFIED"));
|
|
||||||
EXPECT_THAT(segmenter_or.status().GetPayload(kMediaPipeTasksPayload),
|
|
||||||
Optional(absl::Cord(absl::StrCat(
|
|
||||||
MediaPipeTasksStatus::kRunnerInitializationError))));
|
|
||||||
}
|
|
||||||
|
|
||||||
class SegmentationTest : public tflite_shims::testing::Test {};
|
class SegmentationTest : public tflite_shims::testing::Test {};
|
||||||
|
|
||||||
TEST_F(SegmentationTest, SucceedsWithCategoryMask) {
|
TEST_F(SegmentationTest, SucceedsWithCategoryMask) {
|
||||||
|
@ -228,10 +207,10 @@ TEST_F(SegmentationTest, SucceedsWithCategoryMask) {
|
||||||
DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
|
DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
|
||||||
"segmentation_input_rotation0.jpg")));
|
"segmentation_input_rotation0.jpg")));
|
||||||
auto options = std::make_unique<ImageSegmenterOptions>();
|
auto options = std::make_unique<ImageSegmenterOptions>();
|
||||||
options->mutable_segmenter_options()->set_output_type(
|
options->base_options.model_file_name =
|
||||||
SegmenterOptions::CATEGORY_MASK);
|
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
|
||||||
options->mutable_base_options()->mutable_model_file()->set_file_name(
|
options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
|
||||||
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata));
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
||||||
ImageSegmenter::Create(std::move(options)));
|
ImageSegmenter::Create(std::move(options)));
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto category_masks, segmenter->Segment(image));
|
MP_ASSERT_OK_AND_ASSIGN(auto category_masks, segmenter->Segment(image));
|
||||||
|
@ -253,12 +232,11 @@ TEST_F(SegmentationTest, SucceedsWithConfidenceMask) {
|
||||||
Image image,
|
Image image,
|
||||||
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "cat.jpg")));
|
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "cat.jpg")));
|
||||||
auto options = std::make_unique<ImageSegmenterOptions>();
|
auto options = std::make_unique<ImageSegmenterOptions>();
|
||||||
options->mutable_segmenter_options()->set_output_type(
|
options->base_options.model_file_name =
|
||||||
SegmenterOptions::CONFIDENCE_MASK);
|
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
|
||||||
options->mutable_segmenter_options()->set_activation(
|
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
|
||||||
SegmenterOptions::SOFTMAX);
|
options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
|
||||||
options->mutable_base_options()->mutable_model_file()->set_file_name(
|
|
||||||
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata));
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
||||||
ImageSegmenter::Create(std::move(options)));
|
ImageSegmenter::Create(std::move(options)));
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto results, segmenter->Segment(image));
|
MP_ASSERT_OK_AND_ASSIGN(auto results, segmenter->Segment(image));
|
||||||
|
@ -281,17 +259,15 @@ TEST_F(SegmentationTest, SucceedsSelfie128x128Segmentation) {
|
||||||
Image image =
|
Image image =
|
||||||
GetSRGBImage(JoinPath("./", kTestDataDirectory, "mozart_square.jpg"));
|
GetSRGBImage(JoinPath("./", kTestDataDirectory, "mozart_square.jpg"));
|
||||||
auto options = std::make_unique<ImageSegmenterOptions>();
|
auto options = std::make_unique<ImageSegmenterOptions>();
|
||||||
options->mutable_segmenter_options()->set_output_type(
|
options->base_options.model_file_name =
|
||||||
SegmenterOptions::CONFIDENCE_MASK);
|
JoinPath("./", kTestDataDirectory, kSelfie128x128WithMetadata);
|
||||||
options->mutable_segmenter_options()->set_activation(
|
options->base_options.op_resolver =
|
||||||
SegmenterOptions::SOFTMAX);
|
absl::make_unique<SelfieSegmentationModelOpResolver>();
|
||||||
options->mutable_base_options()->mutable_model_file()->set_file_name(
|
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
|
||||||
JoinPath("./", kTestDataDirectory, kSelfie128x128WithMetadata));
|
options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
|
||||||
std::unique_ptr<ImageSegmenter> segmenter,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
||||||
ImageSegmenter::Create(
|
ImageSegmenter::Create(std::move(options)));
|
||||||
std::move(options),
|
|
||||||
absl::make_unique<SelfieSegmentationModelOpResolver>()));
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
|
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
|
||||||
EXPECT_EQ(confidence_masks.size(), 2);
|
EXPECT_EQ(confidence_masks.size(), 2);
|
||||||
|
|
||||||
|
@ -313,15 +289,14 @@ TEST_F(SegmentationTest, SucceedsSelfie144x256Segmentations) {
|
||||||
Image image =
|
Image image =
|
||||||
GetSRGBImage(JoinPath("./", kTestDataDirectory, "mozart_square.jpg"));
|
GetSRGBImage(JoinPath("./", kTestDataDirectory, "mozart_square.jpg"));
|
||||||
auto options = std::make_unique<ImageSegmenterOptions>();
|
auto options = std::make_unique<ImageSegmenterOptions>();
|
||||||
options->mutable_segmenter_options()->set_output_type(
|
options->base_options.model_file_name =
|
||||||
SegmenterOptions::CONFIDENCE_MASK);
|
JoinPath("./", kTestDataDirectory, kSelfie144x256WithMetadata);
|
||||||
options->mutable_base_options()->mutable_model_file()->set_file_name(
|
options->base_options.op_resolver =
|
||||||
JoinPath("./", kTestDataDirectory, kSelfie144x256WithMetadata));
|
absl::make_unique<SelfieSegmentationModelOpResolver>();
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
|
||||||
std::unique_ptr<ImageSegmenter> segmenter,
|
options->activation = ImageSegmenterOptions::Activation::NONE;
|
||||||
ImageSegmenter::Create(
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
||||||
std::move(options),
|
ImageSegmenter::Create(std::move(options)));
|
||||||
absl::make_unique<SelfieSegmentationModelOpResolver>()));
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
|
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
|
||||||
EXPECT_EQ(confidence_masks.size(), 1);
|
EXPECT_EQ(confidence_masks.size(), 1);
|
||||||
|
|
30
mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD
Normal file
30
mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
mediapipe_proto_library(
|
||||||
|
name = "image_segmenter_options_proto",
|
||||||
|
srcs = ["image_segmenter_options.proto"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
|
"//mediapipe/framework:calculator_proto",
|
||||||
|
"//mediapipe/tasks/cc/components:segmenter_options_proto",
|
||||||
|
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
||||||
|
],
|
||||||
|
)
|
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||||
|
|
||||||
syntax = "proto2";
|
syntax = "proto2";
|
||||||
|
|
||||||
package mediapipe.tasks;
|
package mediapipe.tasks.vision.image_segmenter.proto;
|
||||||
|
|
||||||
import "mediapipe/framework/calculator.proto";
|
import "mediapipe/framework/calculator.proto";
|
||||||
import "mediapipe/tasks/cc/components/segmenter_options.proto";
|
import "mediapipe/tasks/cc/components/segmenter_options.proto";
|
||||||
|
@ -25,7 +25,7 @@ message ImageSegmenterOptions {
|
||||||
extend mediapipe.CalculatorOptions {
|
extend mediapipe.CalculatorOptions {
|
||||||
optional ImageSegmenterOptions ext = 458105758;
|
optional ImageSegmenterOptions ext = 458105758;
|
||||||
}
|
}
|
||||||
// Base options for configuring Task library, such as specifying the TfLite
|
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
|
||||||
// model file with metadata, accelerator options, etc.
|
// model file with metadata, accelerator options, etc.
|
||||||
optional core.proto.BaseOptions base_options = 1;
|
optional core.proto.BaseOptions base_options = 1;
|
||||||
|
|
|
@ -36,10 +36,19 @@ namespace vision {
|
||||||
|
|
||||||
// The options for configuring a mediapipe object detector task.
|
// The options for configuring a mediapipe object detector task.
|
||||||
struct ObjectDetectorOptions {
|
struct ObjectDetectorOptions {
|
||||||
// Base options for configuring Task library, such as specifying the TfLite
|
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
|
||||||
// model file with metadata, accelerator options, op resolver, etc.
|
// model file with metadata, accelerator options, op resolver, etc.
|
||||||
tasks::core::BaseOptions base_options;
|
tasks::core::BaseOptions base_options;
|
||||||
|
|
||||||
|
// The running mode of the task. Default to the image mode.
|
||||||
|
// Object detector has three running modes:
|
||||||
|
// 1) The image mode for detecting objects on single image inputs.
|
||||||
|
// 2) The video mode for detecting objects on the decoded frames of a video.
|
||||||
|
// 3) The live stream mode for detecting objects on the live stream of input
|
||||||
|
// data, such as from camera. In this mode, the "result_callback" below must
|
||||||
|
// be specified to receive the detection results asynchronously.
|
||||||
|
core::RunningMode running_mode = core::RunningMode::IMAGE;
|
||||||
|
|
||||||
// The locale to use for display names specified through the TFLite Model
|
// The locale to use for display names specified through the TFLite Model
|
||||||
// Metadata, if any. Defaults to English.
|
// Metadata, if any. Defaults to English.
|
||||||
std::string display_names_locale = "en";
|
std::string display_names_locale = "en";
|
||||||
|
@ -65,15 +74,6 @@ struct ObjectDetectorOptions {
|
||||||
// category names are ignored. Mutually exclusive with category_allowlist.
|
// category names are ignored. Mutually exclusive with category_allowlist.
|
||||||
std::vector<std::string> category_denylist = {};
|
std::vector<std::string> category_denylist = {};
|
||||||
|
|
||||||
// The running mode of the task. Default to the image mode.
|
|
||||||
// Object detector has three running modes:
|
|
||||||
// 1) The image mode for detecting objects on single image inputs.
|
|
||||||
// 2) The video mode for detecting objects on the decoded frames of a video.
|
|
||||||
// 3) The live stream mode for detecting objects on the live stream of input
|
|
||||||
// data, such as from camera. In this mode, the "result_callback" below must
|
|
||||||
// be specified to receive the detection results asynchronously.
|
|
||||||
core::RunningMode running_mode = core::RunningMode::IMAGE;
|
|
||||||
|
|
||||||
// The user-defined result callback for processing live stream data.
|
// The user-defined result callback for processing live stream data.
|
||||||
// The result callback should only be specified when the running mode is set
|
// The result callback should only be specified when the running mode is set
|
||||||
// to RunningMode::LIVE_STREAM.
|
// to RunningMode::LIVE_STREAM.
|
||||||
|
|
|
@ -27,7 +27,7 @@ message ObjectDetectorOptions {
|
||||||
extend mediapipe.CalculatorOptions {
|
extend mediapipe.CalculatorOptions {
|
||||||
optional ObjectDetectorOptions ext = 443442058;
|
optional ObjectDetectorOptions ext = 443442058;
|
||||||
}
|
}
|
||||||
// Base options for configuring Task library, such as specifying the TfLite
|
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
|
||||||
// model file with metadata, accelerator options, etc.
|
// model file with metadata, accelerator options, etc.
|
||||||
optional core.proto.BaseOptions base_options = 1;
|
optional core.proto.BaseOptions base_options = 1;
|
||||||
|
|
||||||
|
|
|
@ -1,75 +0,0 @@
|
||||||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#include "mediapipe/tasks/cc/vision/segmentation/image_segmenter.h"
|
|
||||||
|
|
||||||
#include "mediapipe/framework/api2/builder.h"
|
|
||||||
#include "mediapipe/tasks/cc/core/task_api_factory.h"
|
|
||||||
|
|
||||||
namespace mediapipe {
|
|
||||||
namespace tasks {
|
|
||||||
namespace vision {
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
constexpr char kSegmentationStreamName[] = "segmented_mask_out";
|
|
||||||
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
|
|
||||||
constexpr char kImageStreamName[] = "image_in";
|
|
||||||
constexpr char kImageTag[] = "IMAGE";
|
|
||||||
constexpr char kSubgraphTypeName[] =
|
|
||||||
"mediapipe.tasks.vision.ImageSegmenterGraph";
|
|
||||||
|
|
||||||
using ::mediapipe::CalculatorGraphConfig;
|
|
||||||
using ::mediapipe::Image;
|
|
||||||
|
|
||||||
// Creates a MediaPipe graph config that only contains a single subgraph node of
|
|
||||||
// "mediapipe.tasks.vision.SegmenterGraph".
|
|
||||||
CalculatorGraphConfig CreateGraphConfig(
|
|
||||||
std::unique_ptr<ImageSegmenterOptions> options) {
|
|
||||||
api2::builder::Graph graph;
|
|
||||||
auto& subgraph = graph.AddNode(kSubgraphTypeName);
|
|
||||||
subgraph.GetOptions<ImageSegmenterOptions>().Swap(options.get());
|
|
||||||
graph.In(kImageTag).SetName(kImageStreamName) >> subgraph.In(kImageTag);
|
|
||||||
subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >>
|
|
||||||
graph.Out(kGroupedSegmentationTag);
|
|
||||||
return graph.GetConfig();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
|
|
||||||
std::unique_ptr<ImageSegmenterOptions> options,
|
|
||||||
std::unique_ptr<tflite::OpResolver> resolver) {
|
|
||||||
return core::TaskApiFactory::Create<ImageSegmenter, ImageSegmenterOptions>(
|
|
||||||
CreateGraphConfig(std::move(options)), std::move(resolver));
|
|
||||||
}
|
|
||||||
|
|
||||||
absl::StatusOr<std::vector<Image>> ImageSegmenter::Segment(
|
|
||||||
mediapipe::Image image) {
|
|
||||||
if (image.UsesGpu()) {
|
|
||||||
return CreateStatusWithPayload(
|
|
||||||
absl::StatusCode::kInvalidArgument,
|
|
||||||
absl::StrCat("GPU input images are currently not supported."),
|
|
||||||
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
|
|
||||||
}
|
|
||||||
ASSIGN_OR_RETURN(
|
|
||||||
auto output_packets,
|
|
||||||
runner_->Process({{kImageStreamName,
|
|
||||||
mediapipe::MakePacket<Image>(std::move(image))}}));
|
|
||||||
return output_packets[kSegmentationStreamName].Get<std::vector<Image>>();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace vision
|
|
||||||
} // namespace tasks
|
|
||||||
} // namespace mediapipe
|
|
|
@ -1,76 +0,0 @@
|
||||||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#ifndef MEDIAPIPE_TASKS_CC_VISION_SEGMENTATION_IMAGE_SEGMENTER_H_
|
|
||||||
#define MEDIAPIPE_TASKS_CC_VISION_SEGMENTATION_IMAGE_SEGMENTER_H_
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
#include <utility>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "absl/memory/memory.h"
|
|
||||||
#include "absl/status/statusor.h"
|
|
||||||
#include "mediapipe/framework/formats/image.h"
|
|
||||||
#include "mediapipe/tasks/cc/core/base_task_api.h"
|
|
||||||
#include "mediapipe/tasks/cc/vision/segmentation/image_segmenter_options.pb.h"
|
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
|
||||||
#include "tensorflow/lite/kernels/register.h"
|
|
||||||
|
|
||||||
namespace mediapipe {
|
|
||||||
namespace tasks {
|
|
||||||
namespace vision {
|
|
||||||
|
|
||||||
// Performs segmentation on images.
|
|
||||||
//
|
|
||||||
// The API expects a TFLite model with mandatory TFLite Model Metadata.
|
|
||||||
//
|
|
||||||
// Input tensor:
|
|
||||||
// (kTfLiteUInt8/kTfLiteFloat32)
|
|
||||||
// - image input of size `[batch x height x width x channels]`.
|
|
||||||
// - batch inference is not supported (`batch` is required to be 1).
|
|
||||||
// - RGB and greyscale inputs are supported (`channels` is required to be
|
|
||||||
// 1 or 3).
|
|
||||||
// - if type is kTfLiteFloat32, NormalizationOptions are required to be
|
|
||||||
// attached to the metadata for input normalization.
|
|
||||||
// Output tensors:
|
|
||||||
// (kTfLiteUInt8/kTfLiteFloat32)
|
|
||||||
// - list of segmented masks.
|
|
||||||
// - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1.
|
|
||||||
// - if `output_type` is CONFIDENCE_MASK, float32 Image list of size
|
|
||||||
// `cahnnels`.
|
|
||||||
// - batch is always 1
|
|
||||||
// An example of such model can be found at:
|
|
||||||
// https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2
|
|
||||||
class ImageSegmenter : core::BaseTaskApi {
|
|
||||||
public:
|
|
||||||
using BaseTaskApi::BaseTaskApi;
|
|
||||||
|
|
||||||
// Creates a Segmenter from the provided options. A non-default
|
|
||||||
// OpResolver can be specified in order to support custom Ops or specify a
|
|
||||||
// subset of built-in Ops.
|
|
||||||
static absl::StatusOr<std::unique_ptr<ImageSegmenter>> Create(
|
|
||||||
std::unique_ptr<ImageSegmenterOptions> options,
|
|
||||||
std::unique_ptr<tflite::OpResolver> resolver =
|
|
||||||
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
|
|
||||||
|
|
||||||
// Runs the actual segmentation task.
|
|
||||||
absl::StatusOr<std::vector<mediapipe::Image>> Segment(mediapipe::Image image);
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace vision
|
|
||||||
} // namespace tasks
|
|
||||||
} // namespace mediapipe
|
|
||||||
|
|
||||||
#endif // MEDIAPIPE_TASKS_CC_VISION_SEGMENTATION_IMAGE_SEGMENTER_H_
|
|
1
mediapipe/tasks/testdata/vision/BUILD
vendored
1
mediapipe/tasks/testdata/vision/BUILD
vendored
|
@ -80,6 +80,7 @@ filegroup(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO Create individual filegroup for models required for each Tasks.
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "test_models",
|
name = "test_models",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
|
Loading…
Reference in New Issue
Block a user