Merge branch 'master' into face-landmarker-python

Kinar R 2023-03-16 12:17:59 +05:30 committed by GitHub
commit 15d90bd325
37 changed files with 3833 additions and 138 deletions

LICENSE

@@ -199,3 +199,20 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
===========================================================================
For files under tasks/cc/text/language_detector/custom_ops/utils/utf/
===========================================================================
/*
* The authors of this software are Rob Pike and Ken Thompson.
* Copyright (c) 2002 by Lucent Technologies.
* Permission to use, copy, modify, and distribute this software for any
* purpose without fee is hereby granted, provided that this entire notice
* is included in all copies of any software which is or includes a copy
* or modification of this software and in all copies of the supporting
* documentation for such software.
* THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED
* WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY
* REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY
* OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE.
*/


@@ -96,6 +96,7 @@ cc_library(
"//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
"//mediapipe/tasks/cc/vision/face_detector:face_detector_graph",
"//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
] + select({
# TODO: Build text_classifier_graph and text_embedder_graph on Windows.

mediapipe/tasks/cc/text/language_detector/custom_ops/BUILD

@@ -42,3 +42,36 @@ cc_test(
"@org_tensorflow//tensorflow/lite/kernels:test_util",
],
)
cc_library(
name = "ngram_hash",
srcs = ["ngram_hash.cc"],
hdrs = ["ngram_hash.h"],
copts = tflite_copts(),
deps = [
"//mediapipe/tasks/cc/text/language_detector/custom_ops/utils:ngram_hash_ops_utils",
"//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash:murmur",
"@flatbuffers",
"@org_tensorflow//tensorflow/lite:string_util",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
"@org_tensorflow//tensorflow/lite/kernels:kernel_util",
],
alwayslink = 1,
)
cc_test(
name = "ngram_hash_test",
srcs = ["ngram_hash_test.cc"],
deps = [
":ngram_hash",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash:murmur",
"@com_google_absl//absl/types:optional",
"@flatbuffers",
"@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite:string_util",
"@org_tensorflow//tensorflow/lite/c:common",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
"@org_tensorflow//tensorflow/lite/kernels:test_util",
],
)

mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.cc

@@ -0,0 +1,264 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"
#include <cstdint>
#include <string>
#include <vector>
#include "flatbuffers/flexbuffers.h"
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h"
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/string_util.h"
namespace tflite::ops::custom {
namespace ngram_op {
namespace {
using ::flexbuffers::GetRoot;
using ::flexbuffers::Map;
using ::flexbuffers::TypedVector;
using ::mediapipe::tasks::text::language_detector::custom_ops::
LowercaseUnicodeStr;
using ::mediapipe::tasks::text::language_detector::custom_ops::Tokenize;
using ::mediapipe::tasks::text::language_detector::custom_ops::TokenizedOutput;
using ::mediapipe::tasks::text::language_detector::custom_ops::hash::
MurmurHash64WithSeed;
using ::tflite::GetString;
using ::tflite::StringRef;
constexpr int kInputMessage = 0;
constexpr int kOutputLabel = 0;
constexpr int kDefaultMaxSplits = 128;
// This op takes in a string, finds the character ngrams for it and then
// maps each of these ngrams to an index using the specified vocabulary sizes.
// Input(s):
// - input: Input string.
// - seed: Seed for the random number generator.
// - ngram_lengths: Lengths of each of the ngrams. For example [1, 2, 3] would
// be interpreted as generating unigrams, bigrams, and trigrams.
// - vocab_sizes: Size of the vocabulary for each of the ngram features
//   respectively. The op generates vocab ids that are less than or equal to
//   the vocab size; index 0 denotes an invalid ngram.
// - max_splits: Maximum number of tokens in the output. If this is unset, the
//   limit is `kDefaultMaxSplits`.
// - lower_case_input: If this is set to true, the input string is
//   lower-cased before any processing.
// Output(s):
// - output: A tensor of size [number of ngrams, number of tokens + 2],
// where 2 tokens are reserved for the padding. If `max_splits` is set, this
// length is <= max_splits, otherwise it is <= `kDefaultMaxSplits`.
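// For illustration (a sketch, not part of this op's code): a model containing
// this custom op can be run by registering it under the name "NGramHash",
// the same name the unit tests use:
//
//   tflite::ops::builtin::BuiltinOpResolver resolver;
//   resolver.AddCustom("NGramHash",
//                      tflite::ops::custom::Register_NGRAM_HASH());
//   tflite::InterpreterBuilder builder(*model, resolver);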
// Helper class used for pre-processing the input.
class NGramHashParams {
public:
NGramHashParams(const uint64_t seed, const std::vector<int>& ngram_lengths,
const std::vector<int>& vocab_sizes, int max_splits,
bool lower_case_input)
: seed_(seed),
ngram_lengths_(ngram_lengths),
vocab_sizes_(vocab_sizes),
max_splits_(max_splits),
lower_case_input_(lower_case_input) {}
TfLiteStatus PreprocessInput(const TfLiteTensor* input_t,
TfLiteContext* context) {
if (input_t->bytes == 0) {
context->ReportError(context, "Empty input not supported.");
return kTfLiteError;
}
// Do sanity checks on the input.
if (ngram_lengths_.empty()) {
context->ReportError(context, "`ngram_lengths` must be non-empty.");
return kTfLiteError;
}
if (vocab_sizes_.empty()) {
context->ReportError(context, "`vocab_sizes` must be non-empty.");
return kTfLiteError;
}
if (ngram_lengths_.size() != vocab_sizes_.size()) {
context->ReportError(
context,
"Sizes of `ngram_lengths` and `vocab_sizes` must be the same.");
return kTfLiteError;
}
if (max_splits_ <= 0) {
context->ReportError(context, "`max_splits` must be > 0.");
return kTfLiteError;
}
// Obtain and tokenize the input.
StringRef inputref = GetString(input_t, /*string_index=*/0);
if (lower_case_input_) {
std::string lower_cased_str;
LowercaseUnicodeStr(inputref.str, inputref.len, &lower_cased_str);
tokenized_output_ =
Tokenize(lower_cased_str.c_str(), lower_cased_str.size(), max_splits_,
/*exclude_nonalphaspace_tokens=*/true);
} else {
tokenized_output_ = Tokenize(inputref.str, inputref.len, max_splits_,
/*exclude_nonalphaspace_tokens=*/true);
}
return kTfLiteOk;
}
uint64_t GetSeed() const { return seed_; }
int GetNumTokens() const { return tokenized_output_.tokens.size(); }
int GetNumNGrams() const { return ngram_lengths_.size(); }
std::vector<int> GetNGramLengths() const { return ngram_lengths_; }
std::vector<int> GetVocabSizes() const { return vocab_sizes_; }
const TokenizedOutput& GetTokenizedOutput() const {
return tokenized_output_;
}
TokenizedOutput tokenized_output_;
private:
const uint64_t seed_;
std::vector<int> ngram_lengths_;
std::vector<int> vocab_sizes_;
const int max_splits_;
const bool lower_case_input_;
};
// Convert the TypedVector into a regular std::vector.
std::vector<int> GetIntVector(TypedVector typed_vec) {
std::vector<int> vec(typed_vec.size());
for (int j = 0; j < typed_vec.size(); j++) {
vec[j] = typed_vec[j].AsInt32();
}
return vec;
}
void GetNGramHashIndices(NGramHashParams* params, int32_t* data) {
const int max_unicode_length = params->GetNumTokens();
const auto ngram_lengths = params->GetNGramLengths();
const auto vocab_sizes = params->GetVocabSizes();
const auto& tokenized_output = params->GetTokenizedOutput();
const auto seed = params->GetSeed();
// Compute for each ngram.
for (int ngram = 0; ngram < ngram_lengths.size(); ngram++) {
const int vocab_size = vocab_sizes[ngram];
const int ngram_length = ngram_lengths[ngram];
// Compute for each token within the input.
for (int start = 0; start < tokenized_output.tokens.size(); start++) {
// Compute the number of bytes for the ngram starting at the given
// token.
int num_bytes = 0;
for (int i = start;
i < tokenized_output.tokens.size() && i < (start + ngram_length);
i++) {
num_bytes += tokenized_output.tokens[i].second;
}
// Compute the hash for the ngram starting at the token.
const auto str_hash = MurmurHash64WithSeed(
tokenized_output.str.c_str() + tokenized_output.tokens[start].first,
num_bytes, seed);
// Map the hash to an index in the vocab.
data[ngram * max_unicode_length + start] = (str_hash % vocab_size) + 1;
}
}
}
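// For illustration: with `ngram_lengths` {1, 2} and tokenized input "^hi$",
// row 0 of `data` receives the unigram ids for "^", "h", "i", "$" and row 1
// the bigram ids for "^h", "hi", "i$", "$" (the final ngram of each row is
// truncated at the end of the string), i.e. the buffer is laid out row-major
// as [num_ngrams, num_tokens].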
} // namespace
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
const Map& m = GetRoot(buffer_t, length).AsMap();
const uint64_t seed = m["seed"].AsUInt64();
const std::vector<int> ngram_lengths =
GetIntVector(m["ngram_lengths"].AsTypedVector());
const std::vector<int> vocab_sizes =
GetIntVector(m["vocab_sizes"].AsTypedVector());
const int max_splits =
m["max_splits"].IsNull() ? kDefaultMaxSplits : m["max_splits"].AsInt32();
const bool lowercase_input =
m["lowercase_input"].IsNull() ? true : m["lowercase_input"].AsBool();
return new NGramHashParams(seed, ngram_lengths, vocab_sizes, max_splits,
lowercase_input);
}
void Free(TfLiteContext* context, void* buffer) {
delete reinterpret_cast<NGramHashParams*>(buffer);
}
TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputLabel);
TF_LITE_ENSURE(context, output != nullptr);
SetTensorToDynamic(output);
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
NGramHashParams* params = reinterpret_cast<NGramHashParams*>(node->user_data);
TF_LITE_ENSURE_OK(
context,
params->PreprocessInput(GetInput(context, node, kInputMessage), context));
TfLiteTensor* output = GetOutput(context, node, kOutputLabel);
TF_LITE_ENSURE(context, output != nullptr);
if (IsDynamicTensor(output)) {
TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
output_size->data[0] = 1;
output_size->data[1] = params->GetNumNGrams();
output_size->data[2] = params->GetNumTokens();
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, output, output_size));
} else {
context->ReportError(context, "Output must by dynamic.");
return kTfLiteError;
}
if (output->type == kTfLiteInt32) {
GetNGramHashIndices(params, output->data.i32);
} else {
context->ReportError(context, "Output type must be Int32.");
return kTfLiteError;
}
return kTfLiteOk;
}
} // namespace ngram_op
TfLiteRegistration* Register_NGRAM_HASH() {
static TfLiteRegistration r = {ngram_op::Init, ngram_op::Free,
ngram_op::Resize, ngram_op::Eval};
return &r;
}
} // namespace tflite::ops::custom

mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h

@@ -0,0 +1,27 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_
#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_
#include "tensorflow/lite/kernels/register.h"
namespace tflite::ops::custom {
TfLiteRegistration* Register_NGRAM_HASH();
} // namespace tflite::ops::custom
#endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_

mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash_test.cc

@@ -0,0 +1,313 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"
#include <cstdint>
#include <optional>
#include <string>
#include <vector>
#include "absl/types/optional.h"
#include "flatbuffers/flexbuffers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/string_util.h"
namespace tflite::ops::custom {
namespace {
using ::flexbuffers::Builder;
using ::mediapipe::tasks::text::language_detector::custom_ops::hash::
MurmurHash64WithSeed;
using ::testing::ElementsAreArray;
using ::testing::Message;
// Helper class for testing the op.
class NGramHashModel : public SingleOpModel {
public:
explicit NGramHashModel(const uint64_t seed,
const std::vector<int>& ngram_lengths,
const std::vector<int>& vocab_sizes,
const absl::optional<int> max_splits = std::nullopt) {
// Setup the model inputs.
Builder fbb;
size_t start = fbb.StartMap();
fbb.UInt("seed", seed);
{
size_t start = fbb.StartVector("ngram_lengths");
for (const int& ngram_len : ngram_lengths) {
fbb.Int(ngram_len);
}
fbb.EndVector(start, /*typed=*/true, /*fixed=*/false);
}
{
size_t start = fbb.StartVector("vocab_sizes");
for (const int& vocab_size : vocab_sizes) {
fbb.Int(vocab_size);
}
fbb.EndVector(start, /*typed=*/true, /*fixed=*/false);
}
if (max_splits) {
fbb.Int("max_splits", *max_splits);
}
fbb.EndMap(start);
fbb.Finish();
output_ = AddOutput({TensorType_INT32, {}});
SetCustomOp("NGramHash", fbb.GetBuffer(), Register_NGRAM_HASH);
BuildInterpreter({GetShape(input_)});
}
void SetupInputTensor(const std::string& input) {
PopulateStringTensor(input_, {input});
CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
<< "Cannot allocate tensors";
}
void Invoke(const std::string& input) {
SetupInputTensor(input);
CHECK_EQ(SingleOpModel::Invoke(), kTfLiteOk);
}
TfLiteStatus InvokeUnchecked(const std::string& input) {
SetupInputTensor(input);
return SingleOpModel::Invoke();
}
template <typename T>
std::vector<T> GetOutput() {
return ExtractVector<T>(output_);
}
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
private:
int input_ = AddInput(TensorType_STRING);
int output_;
};
TEST(NGramHashTest, ReturnsExpectedValueWhenInputIsSane) {
// Checks that the op returns the expected value when the input is sane.
// Also checks that when `max_splits` is not specified, the entire string is
// tokenized.
const uint64_t kSeed = 123;
const std::vector<int> vocab_sizes({100, 200});
std::vector<int> ngram_lengths({1, 2});
const std::vector<std::string> testcase_inputs({
"hi",
"wow",
"!",
"HI",
});
// A hash function that maps the given string to an index in the embedding
// table denoted by `vocab_idx`.
auto hash = [vocab_sizes](std::string str, const int vocab_idx) {
const auto hash_value =
MurmurHash64WithSeed(str.c_str(), str.size(), kSeed);
return static_cast<int>((hash_value % vocab_sizes[vocab_idx]) + 1);
};
const std::vector<std::vector<int>> expected_testcase_outputs(
{{
// Unigram & Bigram output for "hi".
hash("^", 0),
hash("h", 0),
hash("i", 0),
hash("$", 0),
hash("^h", 1),
hash("hi", 1),
hash("i$", 1),
hash("$", 1),
},
{
// Unigram & Bigram output for "wow".
hash("^", 0),
hash("w", 0),
hash("o", 0),
hash("w", 0),
hash("$", 0),
hash("^w", 1),
hash("wo", 1),
hash("ow", 1),
hash("w$", 1),
hash("$", 1),
},
{
// Unigram & Bigram output for "!" (which will get replaced by " ").
hash("^", 0),
hash(" ", 0),
hash("$", 0),
hash("^ ", 1),
hash(" $", 1),
hash("$", 1),
},
{
// Unigram & Bigram output for "HI" (which will get lower-cased).
hash("^", 0),
hash("h", 0),
hash("i", 0),
hash("$", 0),
hash("^h", 1),
hash("hi", 1),
hash("i$", 1),
hash("$", 1),
}});
NGramHashModel m(kSeed, ngram_lengths, vocab_sizes);
for (int test_idx = 0; test_idx < testcase_inputs.size(); test_idx++) {
const string& testcase_input = testcase_inputs[test_idx];
m.Invoke(testcase_input);
SCOPED_TRACE(Message() << "Where the testcases' input is: "
<< testcase_input);
EXPECT_THAT(m.GetOutput<int>(),
ElementsAreArray(expected_testcase_outputs[test_idx]));
EXPECT_THAT(m.GetOutputShape(),
ElementsAreArray(
{/*batch_size=*/1, static_cast<int>(ngram_lengths.size()),
static_cast<int>(testcase_input.size()) + /*padding*/ 2}));
}
}
TEST(NGramHashTest, ReturnsExpectedValueWhenMaxSplitsIsSpecified) {
// Checks that the op returns the expected value when `max_splits` is
// specified.
const uint64_t kSeed = 123;
const std::vector<int> vocab_sizes({100, 200});
std::vector<int> ngram_lengths({1, 2});
const std::string testcase_input = "wow";
const std::vector<int> max_splits({2, 3, 4, 5, 6});
// A hash function that maps the given string to an index in the embedding
// table denoted by `vocab_idx`.
auto hash = [vocab_sizes](std::string str, const int vocab_idx) {
const auto hash_value =
MurmurHash64WithSeed(str.c_str(), str.size(), kSeed);
return static_cast<int>((hash_value % vocab_sizes[vocab_idx]) + 1);
};
const std::vector<std::vector<int>> expected_testcase_outputs(
{{
// Unigram & Bigram output for "wow", when `max_splits` == 2.
// We cannot include any of the actual tokens, since `max_splits`
// only allows enough space for the delimiters.
hash("^", 0),
hash("$", 0),
hash("^$", 1),
hash("$", 1),
},
{
// Unigram & Bigram output for "wow", when `max_splits` == 3.
// We can start to include some tokens from the input string.
hash("^", 0),
hash("w", 0),
hash("$", 0),
hash("^w", 1),
hash("w$", 1),
hash("$", 1),
},
{
// Unigram & Bigram output for "wow", when `max_splits` == 4.
hash("^", 0),
hash("w", 0),
hash("o", 0),
hash("$", 0),
hash("^w", 1),
hash("wo", 1),
hash("o$", 1),
hash("$", 1),
},
{
// Unigram & Bigram output for "wow", when `max_splits` == 5.
// We can include the full input string.
hash("^", 0),
hash("w", 0),
hash("o", 0),
hash("w", 0),
hash("$", 0),
hash("^w", 1),
hash("wo", 1),
hash("ow", 1),
hash("w$", 1),
hash("$", 1),
},
{
// Unigram & Bigram output for "wow", when `max_splits` == 6.
// `max_splits` is more than the full input string.
hash("^", 0),
hash("w", 0),
hash("o", 0),
hash("w", 0),
hash("$", 0),
hash("^w", 1),
hash("wo", 1),
hash("ow", 1),
hash("w$", 1),
hash("$", 1),
}});
for (int test_idx = 0; test_idx < max_splits.size(); test_idx++) {
const int testcase_max_splits = max_splits[test_idx];
NGramHashModel m(kSeed, ngram_lengths, vocab_sizes, testcase_max_splits);
m.Invoke(testcase_input);
SCOPED_TRACE(Message() << "Where `max_splits` is: " << testcase_max_splits);
EXPECT_THAT(m.GetOutput<int>(),
ElementsAreArray(expected_testcase_outputs[test_idx]));
EXPECT_THAT(
m.GetOutputShape(),
ElementsAreArray(
{/*batch_size=*/1, static_cast<int>(ngram_lengths.size()),
std::min(
// Longest possible tokenization when using the entire
// input.
static_cast<int>(testcase_input.size()) + /*padding*/ 2,
// Longest possible string when the `max_splits` value
// is < testcase_input.size() + 2 for padding.
testcase_max_splits)}));
}
}
TEST(NGramHashTest, InvalidMaxSplitsValue) {
// Check that the op errors out when given an invalid max splits value.
const std::vector<int> invalid_max_splits({0, -1, -5, -100});
for (const int max_splits : invalid_max_splits) {
NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200},
/*vocab_sizes=*/{1, 2}, /*max_splits=*/max_splits);
EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError);
}
}
TEST(NGramHashTest, MismatchNgramLengthsAndVocabSizes) {
// Check that the op errors out when ngram lengths and vocab sizes mismatch.
{
NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200, 300},
/*vocab_sizes=*/{1, 2});
EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError);
}
{
NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200},
/*vocab_sizes=*/{1, 2, 3});
EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError);
}
}
} // namespace
} // namespace tflite::ops::custom

mediapipe/tasks/cc/text/language_detector/custom_ops/utils/BUILD

@@ -0,0 +1,42 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
cc_library(
name = "ngram_hash_ops_utils",
srcs = [
"ngram_hash_ops_utils.cc",
],
hdrs = [
"ngram_hash_ops_utils.h",
],
deps = [
"//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf",
],
)
cc_test(
name = "ngram_hash_ops_utils_test",
size = "small",
srcs = [
"ngram_hash_ops_utils_test.cc",
],
deps = [
":ngram_hash_ops_utils",
"//mediapipe/framework/port:gtest_main",
],
)

mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.cc

@@ -0,0 +1,96 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h"
#include <string>
#include <utility>
#include <vector>
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h"
namespace mediapipe::tasks::text::language_detector::custom_ops {
TokenizedOutput Tokenize(const char* input_str, int len, int max_tokens,
bool exclude_nonalphaspace_tokens) {
const std::string kPrefix = "^";
const std::string kSuffix = "$";
const std::string kReplacementToken = " ";
TokenizedOutput output;
size_t token_start = 0;
output.str.reserve(len + 2);
output.tokens.reserve(len + 2);
output.str.append(kPrefix);
output.tokens.push_back(std::make_pair(token_start, kPrefix.size()));
token_start += kPrefix.size();
Rune token;
for (int i = 0; i < len && output.tokens.size() + 1 < max_tokens;) {
// Use the standard UTF-8 library to find the next token.
size_t bytes_read = utf_charntorune(&token, input_str + i, len - i);
// Stop processing if we can't read any more tokens, or if we have reached
// the maximum number of allowed tokens, reserving one slot for the suffix.
if (bytes_read == 0) {
break;
}
// If `exclude_nonalphaspace_tokens` is set to true and the token is not
// alphabetic, replace it with the replacement token.
if (exclude_nonalphaspace_tokens && !utf_isalpharune(token)) {
output.str.append(kReplacementToken);
output.tokens.push_back(
std::make_pair(token_start, kReplacementToken.size()));
token_start += kReplacementToken.size();
i += bytes_read;
continue;
}
// Append the token in the output string, and note its position and the
// number of bytes that token consumed.
output.str.append(input_str + i, bytes_read);
output.tokens.push_back(std::make_pair(token_start, bytes_read));
token_start += bytes_read;
i += bytes_read;
}
output.str.append(kSuffix);
output.tokens.push_back(std::make_pair(token_start, kSuffix.size()));
token_start += kSuffix.size();
return output;
}
void LowercaseUnicodeStr(const char* input_str, int len,
std::string* output_str) {
for (int i = 0; i < len;) {
Rune token;
// Tokenize the given string, and get the appropriate lowercase token.
size_t bytes_read = utf_charntorune(&token, input_str + i, len - i);
token = utf_isalpharune(token) ? utf_tolowerrune(token) : token;
// Write back the token to the output string.
char token_buf[UTFmax];
size_t bytes_to_write = utf_runetochar(token_buf, &token);
output_str->append(token_buf, bytes_to_write);
i += bytes_read;
}
}
} // namespace mediapipe::tasks::text::language_detector::custom_ops

mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h

@@ -0,0 +1,56 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_NGRAM_HASH_OPS_UTILS_H_
#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_NGRAM_HASH_OPS_UTILS_H_
#include <string>
#include <utility>
#include <vector>
namespace mediapipe::tasks::text::language_detector::custom_ops {
struct TokenizedOutput {
// The processed string (with necessary prefix, suffix, skipped tokens, etc.).
std::string str;
// Pairs describing each token: the first member is the starting index of
// the token within `str`, and the second is the token's length in bytes.
std::vector<std::pair<const size_t, const size_t>> tokens;
};
// Tokenizes the given input string on Unicode token boundaries, with a maximum
// of `max_tokens` tokens.
//
// If `exclude_nonalphaspace_tokens` is enabled, the tokenization replaces
// non-alphabetic tokens with a replacement token (" ").
//
// The method returns the output in the `TokenizedOutput` struct, which stores
// both the processed input string and the indices and sizes of each token
// within that string.
TokenizedOutput Tokenize(const char* input_str, int len, int max_tokens,
bool exclude_nonalphaspace_tokens);
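// For illustration, an input/output pair taken from the unit tests:
// Tokenize("hi!", /*len=*/3, /*max_tokens=*/100,
// /*exclude_nonalphaspace_tokens=*/true) yields `str` == "^hi $" ("!" is
// replaced by " "), with `tokens` holding the (position, byte length) pairs
// for "^", "h", "i", " ", "$".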
// Converts the given unicode string (`input_str`) with the specified length
// (`len`) to a lowercase string.
//
// The method populates the lowercased string in `output_str`.
void LowercaseUnicodeStr(const char* input_str, int len,
std::string* output_str);
} // namespace mediapipe::tasks::text::language_detector::custom_ops
#endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_NGRAM_HASH_OPS_UTILS_H_

mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils_test.cc

@@ -0,0 +1,135 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h"
#include <string>
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
namespace mediapipe::tasks::text::language_detector::custom_ops {
namespace {
using ::testing::Values;
std::string ReconstructStringFromTokens(TokenizedOutput output) {
std::string reconstructed_str;
for (int i = 0; i < output.tokens.size(); i++) {
reconstructed_str.append(
output.str.c_str() + output.tokens[i].first,
output.str.c_str() + output.tokens[i].first + output.tokens[i].second);
}
return reconstructed_str;
}
struct TokenizeTestParams {
std::string input_str;
size_t max_tokens;
bool exclude_nonalphaspace_tokens;
std::string expected_output_str;
};
class TokenizeParameterizedTest
: public ::testing::Test,
public testing::WithParamInterface<TokenizeTestParams> {};
TEST_P(TokenizeParameterizedTest, Tokenize) {
// Checks that the Tokenize method returns the expected value.
const TokenizeTestParams params = TokenizeParameterizedTest::GetParam();
const TokenizedOutput output = Tokenize(
/*input_str=*/params.input_str.c_str(),
/*len=*/params.input_str.size(),
/*max_tokens=*/params.max_tokens,
/*exclude_nonalphaspace_tokens=*/params.exclude_nonalphaspace_tokens);
// The output string should have the necessary prefixes, and the "!" token
// should have been replaced with a " ".
EXPECT_EQ(output.str, params.expected_output_str);
EXPECT_EQ(ReconstructStringFromTokens(output), params.expected_output_str);
}
INSTANTIATE_TEST_SUITE_P(
TokenizeParameterizedTests, TokenizeParameterizedTest,
Values(
// Test including non-alphanumeric characters.
TokenizeTestParams({/*input_str=*/"hi!", /*max_tokens=*/100,
/*exclude_nonalphaspace_tokens=*/false,
/*expected_output_str=*/"^hi!$"}),
// Test not including non-alphanumeric characters.
TokenizeTestParams({/*input_str=*/"hi!", /*max_tokens=*/100,
/*exclude_nonalphaspace_tokens=*/true,
/*expected_output_str=*/"^hi $"}),
// Test with a maximum of 3 tokens.
TokenizeTestParams({/*input_str=*/"hi!", /*max_tokens=*/3,
/*exclude_nonalphaspace_tokens=*/true,
/*expected_output_str=*/"^h$"}),
// Test with non-latin characters.
TokenizeTestParams({/*input_str=*/"ありがと", /*max_tokens=*/100,
/*exclude_nonalphaspace_tokens=*/true,
/*expected_output_str=*/"^ありがと$"})));
TEST(LowercaseUnicodeTest, TestLowercaseUnicode) {
{
// Check that the method is a no-op when the string is lowercase.
std::string input_str = "hello";
std::string output_str;
LowercaseUnicodeStr(
/*input_str=*/input_str.c_str(),
/*len=*/input_str.size(),
/*output_str=*/&output_str);
EXPECT_EQ(output_str, "hello");
}
{
// Check that the method lowercases uppercase characters.
std::string input_str = "hElLo";
std::string output_str;
LowercaseUnicodeStr(
/*input_str=*/input_str.c_str(),
/*len=*/input_str.size(),
/*output_str=*/&output_str);
EXPECT_EQ(output_str, "hello");
}
{
// Check that the method works with non-latin scripts.
// Cyrillic has the concept of cases, so it should change the input.
std::string input_str = "БЙп";
std::string output_str;
LowercaseUnicodeStr(
/*input_str=*/input_str.c_str(),
/*len=*/input_str.size(),
/*output_str=*/&output_str);
EXPECT_EQ(output_str, "бйп");
}
{
// Check that the method works with non-latin scripts.
// Japanese doesn't have the concept of cases, so it should not change.
std::string input_str = "ありがと";
std::string output_str;
LowercaseUnicodeStr(
/*input_str=*/input_str.c_str(),
/*len=*/input_str.size(),
/*output_str=*/&output_str);
EXPECT_EQ(output_str, "ありがと");
}
}
} // namespace
} // namespace mediapipe::tasks::text::language_detector::custom_ops

mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/BUILD

@@ -0,0 +1,27 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
cc_library(
name = "utf",
srcs = [
"rune.c",
"runetype.c",
"runetypebody.h",
],
hdrs = ["utf.h"],
)

mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/rune.c

@@ -0,0 +1,233 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Forked from a library written by Rob Pike and Ken Thompson. Original
// copyright message below.
/*
* The authors of this software are Rob Pike and Ken Thompson.
* Copyright (c) 2002 by Lucent Technologies.
* Permission to use, copy, modify, and distribute this software for any
* purpose without fee is hereby granted, provided that this entire notice
* is included in all copies of any software which is or includes a copy
* or modification of this software and in all copies of the supporting
* documentation for such software.
* THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED
* WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY
* REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY
* OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE.
*/
#include <stdarg.h>
#include <string.h>
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h"
enum
{
Bit1 = 7,
Bitx = 6,
Bit2 = 5,
Bit3 = 4,
Bit4 = 3,
Bit5 = 2,
T1 = ((1<<(Bit1+1))-1) ^ 0xFF, /* 0000 0000 */
Tx = ((1<<(Bitx+1))-1) ^ 0xFF, /* 1000 0000 */
T2 = ((1<<(Bit2+1))-1) ^ 0xFF, /* 1100 0000 */
T3 = ((1<<(Bit3+1))-1) ^ 0xFF, /* 1110 0000 */
T4 = ((1<<(Bit4+1))-1) ^ 0xFF, /* 1111 0000 */
T5 = ((1<<(Bit5+1))-1) ^ 0xFF, /* 1111 1000 */
Rune1 = (1<<(Bit1+0*Bitx))-1, /* 0000 0000 0111 1111 */
Rune2 = (1<<(Bit2+1*Bitx))-1, /* 0000 0111 1111 1111 */
Rune3 = (1<<(Bit3+2*Bitx))-1, /* 1111 1111 1111 1111 */
Rune4 = (1<<(Bit4+3*Bitx))-1, /* 0001 1111 1111 1111 1111 1111 */
Maskx = (1<<Bitx)-1, /* 0011 1111 */
Testx = Maskx ^ 0xFF, /* 1100 0000 */
Bad = Runeerror,
};
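/*
 * Worked example (for illustration): decoding the two-byte sequence
 * 0xC3 0xA9 ("é"). The lead byte c = 0xC3 satisfies T2 <= c < T3, so this
 * is a two-byte rune; the continuation byte gives c1 = 0xA9 ^ Tx = 0x29,
 * and l = ((0xC3 << Bitx) | 0x29) & Rune2 = 0xE9, i.e. U+00E9.
 */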
/*
* Modified by Wei-Hwa Huang, Google Inc., on 2004-09-24
* This is a slower but "safe" version of the old chartorune
* that works on strings that are not necessarily null-terminated.
*
* If you know for sure that your string is null-terminated,
* chartorune will be a bit faster.
*
* It is guaranteed not to attempt to access "length"
* past the incoming pointer. This is to avoid
* possible access violations. If the string appears to be
* well-formed but incomplete (i.e., to get the whole Rune
* we'd need to read past str+length) then we'll set the Rune
* to Bad and return 0.
*
* Note that if we have decoding problems for other
* reasons, we return 1 instead of 0.
*/
int
utf_charntorune(Rune *rune, const char *str, int length)
{
int c, c1, c2, c3;
long l;
/* When we're not allowed to read anything */
if(length <= 0) {
goto badlen;
}
/*
* one character sequence (7-bit value)
* 00000-0007F => T1
*/
c = *(uchar*)str;
if(c < Tx) {
*rune = c;
return 1;
}
// If we can't read more than one character we must stop
if(length <= 1) {
goto badlen;
}
/*
* two character sequence (11-bit value)
* 0080-07FF => T2 Tx
*/
c1 = *(uchar*)(str+1) ^ Tx;
if(c1 & Testx)
goto bad;
if(c < T3) {
if(c < T2)
goto bad;
l = ((c << Bitx) | c1) & Rune2;
if(l <= Rune1)
goto bad;
*rune = l;
return 2;
}
// If we can't read more than two characters we must stop
if(length <= 2) {
goto badlen;
}
/*
* three character sequence (16-bit value)
* 0800-FFFF => T3 Tx Tx
*/
c2 = *(uchar*)(str+2) ^ Tx;
if(c2 & Testx)
goto bad;
if(c < T4) {
l = ((((c << Bitx) | c1) << Bitx) | c2) & Rune3;
if(l <= Rune2)
goto bad;
*rune = l;
return 3;
}
if (length <= 3)
goto badlen;
/*
* four character sequence (21-bit value)
* 10000-1FFFFF => T4 Tx Tx Tx
*/
c3 = *(uchar*)(str+3) ^ Tx;
if (c3 & Testx)
goto bad;
if (c < T5) {
l = ((((((c << Bitx) | c1) << Bitx) | c2) << Bitx) | c3) & Rune4;
if (l <= Rune3)
goto bad;
if (l > Runemax)
goto bad;
*rune = l;
return 4;
}
// Support for 5-byte or longer UTF-8 would go here, but
// since we don't have that, we'll just fall through to bad.
/*
* bad decoding
*/
bad:
*rune = Bad;
return 1;
badlen:
*rune = Bad;
return 0;
}
int
utf_runetochar(char *str, const Rune *rune)
{
/* Runes are signed, so convert to unsigned for range check. */
unsigned long c;
/*
* one character sequence
* 00000-0007F => 00-7F
*/
c = *rune;
if(c <= Rune1) {
str[0] = c;
return 1;
}
/*
* two character sequence
* 0080-07FF => T2 Tx
*/
if(c <= Rune2) {
str[0] = T2 | (c >> 1*Bitx);
str[1] = Tx | (c & Maskx);
return 2;
}
/*
* If the Rune is out of range, convert it to the error rune.
* Do this test here because the error rune encodes to three bytes.
* Doing it earlier would duplicate work, since an out of range
* Rune wouldn't have fit in one or two bytes.
*/
if (c > Runemax)
c = Runeerror;
/*
* three character sequence
* 0800-FFFF => T3 Tx Tx
*/
if (c <= Rune3) {
str[0] = T3 | (c >> 2*Bitx);
str[1] = Tx | ((c >> 1*Bitx) & Maskx);
str[2] = Tx | (c & Maskx);
return 3;
}
/*
* four character sequence (21-bit value)
* 10000-1FFFFF => T4 Tx Tx Tx
*/
str[0] = T4 | (c >> 3*Bitx);
str[1] = Tx | ((c >> 2*Bitx) & Maskx);
str[2] = Tx | ((c >> 1*Bitx) & Maskx);
str[3] = Tx | (c & Maskx);
return 4;
}

mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/runetype.c

@@ -0,0 +1,54 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Forked from a library written by Rob Pike and Ken Thompson. Original
// copyright message below.
/*
* The authors of this software are Rob Pike and Ken Thompson.
* Copyright (c) 2002 by Lucent Technologies.
* Permission to use, copy, modify, and distribute this software for any
* purpose without fee is hereby granted, provided that this entire notice
* is included in all copies of any software which is or includes a copy
* or modification of this software and in all copies of the supporting
* documentation for such software.
* THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED
* WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY
* REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY
* OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE.
*/
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h"
static
Rune*
rbsearch(Rune c, Rune *t, int n, int ne)
{
Rune *p;
int m;
while(n > 1) {
m = n >> 1;
p = t + m*ne;
if(c >= p[0]) {
t = p;
n = n-m;
} else
n = m;
}
if(n && c >= t[0])
return t;
return 0;
}
#define RUNETYPEBODY
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/runetypebody.h"

mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/runetypebody.h

@@ -0,0 +1,212 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifdef RUNETYPEBODY
static Rune __isalphar[] = {
0x0041, 0x005a, 0x0061, 0x007a, 0x00c0, 0x00d6, 0x00d8, 0x00f6,
0x00f8, 0x02c1, 0x02c6, 0x02d1, 0x02e0, 0x02e4, 0x0370, 0x0374,
0x0376, 0x0377, 0x037a, 0x037d, 0x0388, 0x038a, 0x038e, 0x03a1,
0x03a3, 0x03f5, 0x03f7, 0x0481, 0x048a, 0x0527, 0x0531, 0x0556,
0x0561, 0x0587, 0x05d0, 0x05ea, 0x05f0, 0x05f2, 0x0620, 0x064a,
0x066e, 0x066f, 0x0671, 0x06d3, 0x06e5, 0x06e6, 0x06ee, 0x06ef,
0x06fa, 0x06fc, 0x0712, 0x072f, 0x074d, 0x07a5, 0x07ca, 0x07ea,
0x07f4, 0x07f5, 0x0800, 0x0815, 0x0840, 0x0858, 0x08a2, 0x08ac,
0x0904, 0x0939, 0x0958, 0x0961, 0x0971, 0x0977, 0x0979, 0x097f,
0x0985, 0x098c, 0x098f, 0x0990, 0x0993, 0x09a8, 0x09aa, 0x09b0,
0x09b6, 0x09b9, 0x09dc, 0x09dd, 0x09df, 0x09e1, 0x09f0, 0x09f1,
0x0a05, 0x0a0a, 0x0a0f, 0x0a10, 0x0a13, 0x0a28, 0x0a2a, 0x0a30,
0x0a32, 0x0a33, 0x0a35, 0x0a36, 0x0a38, 0x0a39, 0x0a59, 0x0a5c,
0x0a72, 0x0a74, 0x0a85, 0x0a8d, 0x0a8f, 0x0a91, 0x0a93, 0x0aa8,
0x0aaa, 0x0ab0, 0x0ab2, 0x0ab3, 0x0ab5, 0x0ab9, 0x0ae0, 0x0ae1,
0x0b05, 0x0b0c, 0x0b0f, 0x0b10, 0x0b13, 0x0b28, 0x0b2a, 0x0b30,
0x0b32, 0x0b33, 0x0b35, 0x0b39, 0x0b5c, 0x0b5d, 0x0b5f, 0x0b61,
0x0b85, 0x0b8a, 0x0b8e, 0x0b90, 0x0b92, 0x0b95, 0x0b99, 0x0b9a,
0x0b9e, 0x0b9f, 0x0ba3, 0x0ba4, 0x0ba8, 0x0baa, 0x0bae, 0x0bb9,
0x0c05, 0x0c0c, 0x0c0e, 0x0c10, 0x0c12, 0x0c28, 0x0c2a, 0x0c33,
0x0c35, 0x0c39, 0x0c58, 0x0c59, 0x0c60, 0x0c61, 0x0c85, 0x0c8c,
0x0c8e, 0x0c90, 0x0c92, 0x0ca8, 0x0caa, 0x0cb3, 0x0cb5, 0x0cb9,
0x0ce0, 0x0ce1, 0x0cf1, 0x0cf2, 0x0d05, 0x0d0c, 0x0d0e, 0x0d10,
0x0d12, 0x0d3a, 0x0d60, 0x0d61, 0x0d7a, 0x0d7f, 0x0d85, 0x0d96,
0x0d9a, 0x0db1, 0x0db3, 0x0dbb, 0x0dc0, 0x0dc6, 0x0e01, 0x0e30,
0x0e32, 0x0e33, 0x0e40, 0x0e46, 0x0e81, 0x0e82, 0x0e87, 0x0e88,
0x0e94, 0x0e97, 0x0e99, 0x0e9f, 0x0ea1, 0x0ea3, 0x0eaa, 0x0eab,
0x0ead, 0x0eb0, 0x0eb2, 0x0eb3, 0x0ec0, 0x0ec4, 0x0edc, 0x0edf,
0x0f40, 0x0f47, 0x0f49, 0x0f6c, 0x0f88, 0x0f8c, 0x1000, 0x102a,
0x1050, 0x1055, 0x105a, 0x105d, 0x1065, 0x1066, 0x106e, 0x1070,
0x1075, 0x1081, 0x10a0, 0x10c5, 0x10d0, 0x10fa, 0x10fc, 0x1248,
0x124a, 0x124d, 0x1250, 0x1256, 0x125a, 0x125d, 0x1260, 0x1288,
0x128a, 0x128d, 0x1290, 0x12b0, 0x12b2, 0x12b5, 0x12b8, 0x12be,
0x12c2, 0x12c5, 0x12c8, 0x12d6, 0x12d8, 0x1310, 0x1312, 0x1315,
0x1318, 0x135a, 0x1380, 0x138f, 0x13a0, 0x13f4, 0x1401, 0x166c,
0x166f, 0x167f, 0x1681, 0x169a, 0x16a0, 0x16ea, 0x1700, 0x170c,
0x170e, 0x1711, 0x1720, 0x1731, 0x1740, 0x1751, 0x1760, 0x176c,
0x176e, 0x1770, 0x1780, 0x17b3, 0x1820, 0x1877, 0x1880, 0x18a8,
0x18b0, 0x18f5, 0x1900, 0x191c, 0x1950, 0x196d, 0x1970, 0x1974,
0x1980, 0x19ab, 0x19c1, 0x19c7, 0x1a00, 0x1a16, 0x1a20, 0x1a54,
0x1b05, 0x1b33, 0x1b45, 0x1b4b, 0x1b83, 0x1ba0, 0x1bae, 0x1baf,
0x1bba, 0x1be5, 0x1c00, 0x1c23, 0x1c4d, 0x1c4f, 0x1c5a, 0x1c7d,
0x1ce9, 0x1cec, 0x1cee, 0x1cf1, 0x1cf5, 0x1cf6, 0x1d00, 0x1dbf,
0x1e00, 0x1f15, 0x1f18, 0x1f1d, 0x1f20, 0x1f45, 0x1f48, 0x1f4d,
0x1f50, 0x1f57, 0x1f5f, 0x1f7d, 0x1f80, 0x1fb4, 0x1fb6, 0x1fbc,
0x1fc2, 0x1fc4, 0x1fc6, 0x1fcc, 0x1fd0, 0x1fd3, 0x1fd6, 0x1fdb,
0x1fe0, 0x1fec, 0x1ff2, 0x1ff4, 0x1ff6, 0x1ffc, 0x2090, 0x209c,
0x210a, 0x2113, 0x2119, 0x211d, 0x212a, 0x212d, 0x212f, 0x2139,
0x213c, 0x213f, 0x2145, 0x2149, 0x2183, 0x2184, 0x2c00, 0x2c2e,
0x2c30, 0x2c5e, 0x2c60, 0x2ce4, 0x2ceb, 0x2cee, 0x2cf2, 0x2cf3,
0x2d00, 0x2d25, 0x2d30, 0x2d67, 0x2d80, 0x2d96, 0x2da0, 0x2da6,
0x2da8, 0x2dae, 0x2db0, 0x2db6, 0x2db8, 0x2dbe, 0x2dc0, 0x2dc6,
0x2dc8, 0x2dce, 0x2dd0, 0x2dd6, 0x2dd8, 0x2dde, 0x3005, 0x3006,
0x3031, 0x3035, 0x303b, 0x303c, 0x3041, 0x3096, 0x309d, 0x309f,
0x30a1, 0x30fa, 0x30fc, 0x30ff, 0x3105, 0x312d, 0x3131, 0x318e,
0x31a0, 0x31ba, 0x31f0, 0x31ff, 0x3400, 0x4db5, 0x4e00, 0x9fcc,
0xa000, 0xa48c, 0xa4d0, 0xa4fd, 0xa500, 0xa60c, 0xa610, 0xa61f,
0xa62a, 0xa62b, 0xa640, 0xa66e, 0xa67f, 0xa697, 0xa6a0, 0xa6e5,
0xa717, 0xa71f, 0xa722, 0xa788, 0xa78b, 0xa78e, 0xa790, 0xa793,
0xa7a0, 0xa7aa, 0xa7f8, 0xa801, 0xa803, 0xa805, 0xa807, 0xa80a,
0xa80c, 0xa822, 0xa840, 0xa873, 0xa882, 0xa8b3, 0xa8f2, 0xa8f7,
0xa90a, 0xa925, 0xa930, 0xa946, 0xa960, 0xa97c, 0xa984, 0xa9b2,
0xaa00, 0xaa28, 0xaa40, 0xaa42, 0xaa44, 0xaa4b, 0xaa60, 0xaa76,
0xaa80, 0xaaaf, 0xaab5, 0xaab6, 0xaab9, 0xaabd, 0xaadb, 0xaadd,
0xaae0, 0xaaea, 0xaaf2, 0xaaf4, 0xab01, 0xab06, 0xab09, 0xab0e,
0xab11, 0xab16, 0xab20, 0xab26, 0xab28, 0xab2e, 0xabc0, 0xabe2,
0xac00, 0xd7a3, 0xd7b0, 0xd7c6, 0xd7cb, 0xd7fb, 0xf900, 0xfa6d,
0xfa70, 0xfad9, 0xfb00, 0xfb06, 0xfb13, 0xfb17, 0xfb1f, 0xfb28,
0xfb2a, 0xfb36, 0xfb38, 0xfb3c, 0xfb40, 0xfb41, 0xfb43, 0xfb44,
0xfb46, 0xfbb1, 0xfbd3, 0xfd3d, 0xfd50, 0xfd8f, 0xfd92, 0xfdc7,
0xfdf0, 0xfdfb, 0xfe70, 0xfe74, 0xfe76, 0xfefc, 0xff21, 0xff3a,
0xff41, 0xff5a, 0xff66, 0xffbe, 0xffc2, 0xffc7, 0xffca, 0xffcf,
0xffd2, 0xffd7, 0xffda, 0xffdc, 0x10000, 0x1000b, 0x1000d, 0x10026,
0x10028, 0x1003a, 0x1003c, 0x1003d, 0x1003f, 0x1004d, 0x10050, 0x1005d,
0x10080, 0x100fa, 0x10280, 0x1029c, 0x102a0, 0x102d0, 0x10300, 0x1031e,
0x10330, 0x10340, 0x10342, 0x10349, 0x10380, 0x1039d, 0x103a0, 0x103c3,
0x103c8, 0x103cf, 0x10400, 0x1049d, 0x10800, 0x10805, 0x1080a, 0x10835,
0x10837, 0x10838, 0x1083f, 0x10855, 0x10900, 0x10915, 0x10920, 0x10939,
0x10980, 0x109b7, 0x109be, 0x109bf, 0x10a10, 0x10a13, 0x10a15, 0x10a17,
0x10a19, 0x10a33, 0x10a60, 0x10a7c, 0x10b00, 0x10b35, 0x10b40, 0x10b55,
0x10b60, 0x10b72, 0x10c00, 0x10c48, 0x11003, 0x11037, 0x11083, 0x110af,
0x110d0, 0x110e8, 0x11103, 0x11126, 0x11183, 0x111b2, 0x111c1, 0x111c4,
0x11680, 0x116aa, 0x12000, 0x1236e, 0x13000, 0x1342e, 0x16800, 0x16a38,
0x16f00, 0x16f44, 0x16f93, 0x16f9f, 0x1b000, 0x1b001, 0x1d400, 0x1d454,
0x1d456, 0x1d49c, 0x1d49e, 0x1d49f, 0x1d4a5, 0x1d4a6, 0x1d4a9, 0x1d4ac,
0x1d4ae, 0x1d4b9, 0x1d4bd, 0x1d4c3, 0x1d4c5, 0x1d505, 0x1d507, 0x1d50a,
0x1d50d, 0x1d514, 0x1d516, 0x1d51c, 0x1d51e, 0x1d539, 0x1d53b, 0x1d53e,
0x1d540, 0x1d544, 0x1d54a, 0x1d550, 0x1d552, 0x1d6a5, 0x1d6a8, 0x1d6c0,
0x1d6c2, 0x1d6da, 0x1d6dc, 0x1d6fa, 0x1d6fc, 0x1d714, 0x1d716, 0x1d734,
0x1d736, 0x1d74e, 0x1d750, 0x1d76e, 0x1d770, 0x1d788, 0x1d78a, 0x1d7a8,
0x1d7aa, 0x1d7c2, 0x1d7c4, 0x1d7cb, 0x1ee00, 0x1ee03, 0x1ee05, 0x1ee1f,
0x1ee21, 0x1ee22, 0x1ee29, 0x1ee32, 0x1ee34, 0x1ee37, 0x1ee4d, 0x1ee4f,
0x1ee51, 0x1ee52, 0x1ee61, 0x1ee62, 0x1ee67, 0x1ee6a, 0x1ee6c, 0x1ee72,
0x1ee74, 0x1ee77, 0x1ee79, 0x1ee7c, 0x1ee80, 0x1ee89, 0x1ee8b, 0x1ee9b,
0x1eea1, 0x1eea3, 0x1eea5, 0x1eea9, 0x1eeab, 0x1eebb, 0x20000, 0x2a6d6,
0x2a700, 0x2b734, 0x2b740, 0x2b81d, 0x2f800, 0x2fa1d,
};
static Rune __isalphas[] = {
0x00aa, 0x00b5, 0x00ba, 0x02ec, 0x02ee, 0x0386, 0x038c, 0x0559,
0x06d5, 0x06ff, 0x0710, 0x07b1, 0x07fa, 0x081a, 0x0824, 0x0828,
0x08a0, 0x093d, 0x0950, 0x09b2, 0x09bd, 0x09ce, 0x0a5e, 0x0abd,
0x0ad0, 0x0b3d, 0x0b71, 0x0b83, 0x0b9c, 0x0bd0, 0x0c3d, 0x0cbd,
0x0cde, 0x0d3d, 0x0d4e, 0x0dbd, 0x0e84, 0x0e8a, 0x0e8d, 0x0ea5,
0x0ea7, 0x0ebd, 0x0ec6, 0x0f00, 0x103f, 0x1061, 0x108e, 0x10c7,
0x10cd, 0x1258, 0x12c0, 0x17d7, 0x17dc, 0x18aa, 0x1aa7, 0x1f59,
0x1f5b, 0x1f5d, 0x1fbe, 0x2071, 0x207f, 0x2102, 0x2107, 0x2115,
0x2124, 0x2126, 0x2128, 0x214e, 0x2d27, 0x2d2d, 0x2d6f, 0x2e2f,
0xa8fb, 0xa9cf, 0xaa7a, 0xaab1, 0xaac0, 0xaac2, 0xfb1d, 0xfb3e,
0x10808, 0x1083c, 0x10a00, 0x16f50, 0x1d4a2, 0x1d4bb, 0x1d546, 0x1ee24,
0x1ee27, 0x1ee39, 0x1ee3b, 0x1ee42, 0x1ee47, 0x1ee49, 0x1ee4b, 0x1ee54,
0x1ee57, 0x1ee59, 0x1ee5b, 0x1ee5d, 0x1ee5f, 0x1ee64, 0x1ee7e,
};
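/* The two tables above encode the alphabetic code points compactly:
 * __isalphar stores inclusive [start, end] range pairs, and __isalphas
 * stores isolated single code points; both are binary-searched with
 * rbsearch below. */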
int utf_isalpharune(Rune c) {
Rune *p;
p = rbsearch(c, __isalphar, nelem(__isalphar) / 2, 2);
if (p && c >= p[0] && c <= p[1]) return 1;
p = rbsearch(c, __isalphas, nelem(__isalphas), 1);
if (p && c == p[0]) return 1;
return 0;
}
static Rune __tolowerr[] = {
0x0041, 0x005a, 1048608, 0x00c0, 0x00d6, 1048608, 0x00d8, 0x00de, 1048608,
0x0189, 0x018a, 1048781, 0x01b1, 0x01b2, 1048793, 0x0388, 0x038a, 1048613,
0x038e, 0x038f, 1048639, 0x0391, 0x03a1, 1048608, 0x03a3, 0x03ab, 1048608,
0x03fd, 0x03ff, 1048446, 0x0400, 0x040f, 1048656, 0x0410, 0x042f, 1048608,
0x0531, 0x0556, 1048624, 0x10a0, 0x10c5, 1055840, 0x1f08, 0x1f0f, 1048568,
0x1f18, 0x1f1d, 1048568, 0x1f28, 0x1f2f, 1048568, 0x1f38, 0x1f3f, 1048568,
0x1f48, 0x1f4d, 1048568, 0x1f68, 0x1f6f, 1048568, 0x1f88, 0x1f8f, 1048568,
0x1f98, 0x1f9f, 1048568, 0x1fa8, 0x1faf, 1048568, 0x1fb8, 0x1fb9, 1048568,
0x1fba, 0x1fbb, 1048502, 0x1fc8, 0x1fcb, 1048490, 0x1fd8, 0x1fd9, 1048568,
0x1fda, 0x1fdb, 1048476, 0x1fe8, 0x1fe9, 1048568, 0x1fea, 0x1feb, 1048464,
0x1ff8, 0x1ff9, 1048448, 0x1ffa, 0x1ffb, 1048450, 0x2160, 0x216f, 1048592,
0x24b6, 0x24cf, 1048602, 0x2c00, 0x2c2e, 1048624, 0x2c7e, 0x2c7f, 1037761,
0xff21, 0xff3a, 1048608, 0x10400, 0x10427, 1048616,
};
static Rune __tolowerp[] = {
0x0100, 0x012e, 1048577, 0x0132, 0x0136, 1048577, 0x0139, 0x0147, 1048577,
0x014a, 0x0176, 1048577, 0x017b, 0x017d, 1048577, 0x01a2, 0x01a4, 1048577,
0x01b3, 0x01b5, 1048577, 0x01cd, 0x01db, 1048577, 0x01de, 0x01ee, 1048577,
0x01f8, 0x021e, 1048577, 0x0222, 0x0232, 1048577, 0x0248, 0x024e, 1048577,
0x0370, 0x0372, 1048577, 0x03d8, 0x03ee, 1048577, 0x0460, 0x0480, 1048577,
0x048a, 0x04be, 1048577, 0x04c3, 0x04cd, 1048577, 0x04d0, 0x0526, 1048577,
0x1e00, 0x1e94, 1048577, 0x1ea0, 0x1efe, 1048577, 0x1f59, 0x1f5f, 1048568,
0x2c67, 0x2c6b, 1048577, 0x2c80, 0x2ce2, 1048577, 0x2ceb, 0x2ced, 1048577,
0xa640, 0xa66c, 1048577, 0xa680, 0xa696, 1048577, 0xa722, 0xa72e, 1048577,
0xa732, 0xa76e, 1048577, 0xa779, 0xa77b, 1048577, 0xa780, 0xa786, 1048577,
0xa790, 0xa792, 1048577, 0xa7a0, 0xa7a8, 1048577,
};
static Rune __tolowers[] = {
0x0130, 1048377, 0x0178, 1048455, 0x0179, 1048577, 0x0181, 1048786,
0x0182, 1048577, 0x0184, 1048577, 0x0186, 1048782, 0x0187, 1048577,
0x018b, 1048577, 0x018e, 1048655, 0x018f, 1048778, 0x0190, 1048779,
0x0191, 1048577, 0x0193, 1048781, 0x0194, 1048783, 0x0196, 1048787,
0x0197, 1048785, 0x0198, 1048577, 0x019c, 1048787, 0x019d, 1048789,
0x019f, 1048790, 0x01a0, 1048577, 0x01a6, 1048794, 0x01a7, 1048577,
0x01a9, 1048794, 0x01ac, 1048577, 0x01ae, 1048794, 0x01af, 1048577,
0x01b7, 1048795, 0x01b8, 1048577, 0x01bc, 1048577, 0x01c4, 1048578,
0x01c5, 1048577, 0x01c7, 1048578, 0x01c8, 1048577, 0x01ca, 1048578,
0x01cb, 1048577, 0x01f1, 1048578, 0x01f2, 1048577, 0x01f4, 1048577,
0x01f6, 1048479, 0x01f7, 1048520, 0x0220, 1048446, 0x023a, 1059371,
0x023b, 1048577, 0x023d, 1048413, 0x023e, 1059368, 0x0241, 1048577,
0x0243, 1048381, 0x0244, 1048645, 0x0245, 1048647, 0x0246, 1048577,
0x0376, 1048577, 0x0386, 1048614, 0x038c, 1048640, 0x03cf, 1048584,
0x03f4, 1048516, 0x03f7, 1048577, 0x03f9, 1048569, 0x03fa, 1048577,
0x04c0, 1048591, 0x04c1, 1048577, 0x10c7, 1055840, 0x10cd, 1055840,
0x1e9e, 1040961, 0x1fbc, 1048567, 0x1fcc, 1048567, 0x1fec, 1048569,
0x1ffc, 1048567, 0x2126, 1041059, 0x212a, 1040193, 0x212b, 1040314,
0x2132, 1048604, 0x2183, 1048577, 0x2c60, 1048577, 0x2c62, 1037833,
0x2c63, 1044762, 0x2c64, 1037849, 0x2c6d, 1037796, 0x2c6e, 1037827,
0x2c6f, 1037793, 0x2c70, 1037794, 0x2c72, 1048577, 0x2c75, 1048577,
0x2cf2, 1048577, 0xa77d, 1013244, 0xa77e, 1048577, 0xa78b, 1048577,
0xa78d, 1006296, 0xa7aa, 1006268,
};
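/* The mapping entries above are biased by 2^20 (1048576) so that negative
 * deltas fit in the tables: the lowercase form is c + entry - 1048576. For
 * example, the first __tolowerr row maps 0x41..0x5a ('A'..'Z') with entry
 * 1048608, a delta of +32, onto 0x61..0x7a ('a'..'z'). */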
Rune utf_tolowerrune(Rune c) {
Rune *p;
p = rbsearch(c, __tolowerr, nelem(__tolowerr) / 3, 3);
if (p && c >= p[0] && c <= p[1]) return c + p[2] - 1048576;
p = rbsearch(c, __tolowerp, nelem(__tolowerp) / 3, 3);
if (p && c >= p[0] && c <= p[1] && !((c - p[0]) & 1))
return c + p[2] - 1048576;
p = rbsearch(c, __tolowers, nelem(__tolowers) / 2, 2);
if (p && c == p[0]) return c + p[1] - 1048576;
return c;
}
#endif

mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h

@@ -0,0 +1,98 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Fork of several UTF utils originally written by Rob Pike and Ken Thompson.
#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_UTF_UTF_H_
#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_UTF_UTF_H_ 1
#include <stdint.h>
// Code-point values in Unicode 4.0 are 21 bits wide.
typedef signed int Rune;
#define uchar _utfuchar
typedef unsigned char uchar;
#define nelem(x) (sizeof(x) / sizeof((x)[0]))
enum {
UTFmax = 4, // maximum bytes per rune
Runeerror = 0xFFFD, // decoding error in UTF
Runemax = 0x10FFFF, // maximum rune value
};
#ifdef __cplusplus
extern "C" {
#endif
/*
* rune routines
*/
/*
* These routines were written by Rob Pike and Ken Thompson
* and first appeared in Plan 9.
* SEE ALSO
* utf (7)
* tcs (1)
*/
// utf_runetochar copies (encodes) one rune, pointed to by r, to at most
// UTFmax bytes starting at s and returns the number of bytes generated.
int utf_runetochar(char* s, const Rune* r);
// utf_charntorune copies (decodes) at most UTFmax bytes starting at `str` to
// one rune, pointed to by `rune`, accesses at most `length` bytes of `str`, and
// returns the number of bytes consumed.
// If the UTF sequence is incomplete within `length` bytes,
// utf_charntorune will set *rune to Runeerror and return 0. If it is complete
// but not valid UTF-8, it will set *rune to Runeerror and return 1.
//
// Added 2004-09-24 by Wei-Hwa Huang
int utf_charntorune(Rune* rune, const char* str, int length);
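// For illustration, a round trip through the two routines (sketch):
//
//   Rune in = 0x3042;  // "あ"
//   char buf[UTFmax];
//   int n = utf_runetochar(buf, &in);       // encodes 3 bytes
//   Rune out;
//   int m = utf_charntorune(&out, buf, n);  // m == 3, out == 0x3042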
// Unicode defines some characters as letters and
// specifies three cases: upper, lower, and title. Mappings among the
// cases are also defined, although they are not exhaustive: some
// upper case letters have no lower case mapping, and so on. Unicode
// also defines several character properties, a subset of which are
// checked by these routines. These routines are based on Unicode
// version 3.0.0.
//
// NOTE: The routines are implemented in C, so utf_isalpharune returns 0 for false
// and 1 for true.
//
// utf_tolowerrune is the Unicode case mapping. It returns the character
// unchanged if it has no defined mapping.
Rune utf_tolowerrune(Rune r);
// utf_isalpharune tests for Unicode letters; this includes ideographs in
// addition to alphabetic characters.
int utf_isalpharune(Rune r);
// (The comments in this file were copied from the manpage files rune.3,
// isalpharune.3, and runestrcat.3. Some formatting changes were also made
// to conform to Google style. /JRM 11/11/05)
#ifdef __cplusplus
}
#endif
#endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_UTF_UTF_H_


@@ -80,6 +80,7 @@ cc_library(
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto",
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
"//mediapipe/tasks/metadata:image_segmenter_metadata_schema_cc",
"//mediapipe/tasks/metadata:metadata_schema_cc",
"//mediapipe/util:label_map_cc_proto",
"//mediapipe/util:label_map_util",

View File

@ -15,6 +15,7 @@ limitations under the License.
#include <algorithm>
#include <cstdint>
#include <limits>
#include <memory>
#include <ostream>
#include <string>
@ -79,6 +80,133 @@ void Sigmoid(absl::Span<const float> values,
                 [](float value) { return 1. / (1 + std::exp(-value)); });
}
std::vector<Image> ProcessForCategoryMaskCpu(const Shape& input_shape,
const Shape& output_shape,
const SegmenterOptions& options,
const float* tensors_buffer) {
cv::Mat resized_tensors_mat;
cv::Mat tensors_mat_view(
input_shape.height, input_shape.width, CV_32FC(input_shape.channels),
reinterpret_cast<void*>(const_cast<float*>(tensors_buffer)));
if (output_shape.height == input_shape.height &&
output_shape.width == input_shape.width) {
resized_tensors_mat = tensors_mat_view;
} else {
// Resize input tensors to output size.
    // TODO(b/273633027) Use an efficient way to find values for the category
    // mask instead of resizing the whole tensor.
cv::resize(tensors_mat_view, resized_tensors_mat,
{output_shape.width, output_shape.height}, 0, 0,
cv::INTER_LINEAR);
}
// Category mask Image.
ImageFrameSharedPtr image_frame_ptr = std::make_shared<ImageFrame>(
ImageFormat::GRAY8, output_shape.width, output_shape.height, 1);
Image category_mask(image_frame_ptr);
// Fill in the maximum category in the category mask image.
cv::Mat category_mask_mat_view =
mediapipe::formats::MatView(image_frame_ptr.get());
const int input_channels = input_shape.channels;
category_mask_mat_view.forEach<uint8_t>(
[&resized_tensors_mat, &input_channels, &options](uint8_t& pixel,
const int position[]) {
float* tensors_buffer =
resized_tensors_mat.ptr<float>(position[0], position[1]);
absl::Span<float> confidence_scores(tensors_buffer, input_channels);
        // Only apply the activation function if it is SIGMOID. If NONE,
        // no activation is needed. If SOFTMAX, input_channels is required
        // to be > 1, and for input_channels > 1 no activation is needed to
        // find the maximum value.
if (options.activation() == SegmenterOptions::SIGMOID) {
Sigmoid(confidence_scores, confidence_scores);
}
if (input_channels == 1) {
          // If the input tensor is a single mask, it is assumed to be a binary
          // foreground segmentation mask. For such a mask, we make the
          // foreground category 1 and the background category 0.
pixel = static_cast<uint8_t>(*tensors_buffer > 0.5f);
} else {
const int maximum_category_idx =
std::max_element(confidence_scores.begin(),
confidence_scores.end()) -
confidence_scores.begin();
pixel = maximum_category_idx;
}
});
return {category_mask};
}
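The per-pixel reduction above can be summarized without the OpenCV and MediaPipe plumbing. The following standalone sketch (the name ArgmaxMask is illustrative, not part of this file) implements the same single-channel threshold and multi-channel argmax over a row-major HxWxC buffer:

#include <algorithm>
#include <cstdint>
#include <vector>

// `tensors` holds height x width x channels floats in row-major order,
// mirroring the `tensors_buffer` layout used above.
std::vector<uint8_t> ArgmaxMask(const std::vector<float>& tensors, int height,
                                int width, int channels) {
  std::vector<uint8_t> mask(height * width);
  for (int i = 0; i < height * width; ++i) {
    const float* pixel = &tensors[i * channels];
    if (channels == 1) {
      // Single-channel input: treat as a binary foreground mask.
      mask[i] = pixel[0] > 0.5f ? 1 : 0;
    } else {
      // Multi-channel input: pick the channel with the highest score.
      mask[i] = static_cast<uint8_t>(
          std::max_element(pixel, pixel + channels) - pixel);
    }
  }
  return mask;
}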
std::vector<Image> ProcessForConfidenceMaskCpu(const Shape& input_shape,
const Shape& output_shape,
const SegmenterOptions& options,
const float* tensors_buffer) {
std::function<void(absl::Span<const float> values,
absl::Span<float> activated_values)>
activation_fn;
switch (options.activation()) {
case SegmenterOptions::SIGMOID:
activation_fn = &Sigmoid;
break;
case SegmenterOptions::SOFTMAX:
activation_fn = &StableSoftmax;
break;
case SegmenterOptions::NONE:
// Just copying for NONE activation.
activation_fn = [](absl::Span<const float> values,
absl::Span<float> activated_values) {
std::copy(values.begin(), values.end(), activated_values.begin());
};
break;
}
// TODO Use libyuv for resizing instead.
std::vector<Image> confidence_masks;
std::vector<cv::Mat> confidence_mask_mats;
confidence_masks.reserve(input_shape.channels);
confidence_mask_mats.reserve(input_shape.channels);
for (int i = 0; i < input_shape.channels; ++i) {
confidence_masks.push_back(Image(std::make_shared<ImageFrame>(
ImageFormat::VEC32F1, input_shape.width, input_shape.height, 1)));
confidence_mask_mats.push_back(mediapipe::formats::MatView(
confidence_masks.back().GetImageFrameSharedPtr().get()));
}
// Applies activation function.
const int tensor_size = input_shape.height * input_shape.width;
std::vector<float> activated_values(input_shape.channels);
absl::Span<float> activated_values_span(activated_values);
for (int i = 0; i < tensor_size; ++i) {
activation_fn(absl::MakeConstSpan(&tensors_buffer[i * input_shape.channels],
input_shape.channels),
activated_values_span);
for (int j = 0; j < input_shape.channels; ++j) {
confidence_mask_mats[j].at<float>(
i / input_shape.width, i % input_shape.width) = activated_values[j];
}
}
if (output_shape.height == input_shape.height &&
output_shape.width == input_shape.width) {
return confidence_masks;
}
std::vector<Image> resized_confidence_masks;
resized_confidence_masks.reserve(confidence_mask_mats.size());
// Resizes segmented masks to required output size.
for (int i = 0; i < confidence_mask_mats.size(); i++) {
// Pre-allocates ImageFrame memory to avoid copying from cv::Mat
// afterward.
ImageFrameSharedPtr image_frame_ptr = std::make_shared<ImageFrame>(
ImageFormat::VEC32F1, output_shape.width, output_shape.height, 1);
cv::Mat resized_mask_mat_view =
mediapipe::formats::MatView(image_frame_ptr.get());
cv::resize(confidence_mask_mats[i], resized_mask_mat_view,
resized_mask_mat_view.size(), 0, 0, cv::INTER_LINEAR);
resized_confidence_masks.push_back(Image(image_frame_ptr));
}
return resized_confidence_masks;
}
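`StableSoftmax` is referenced above but its body is not part of this hunk; a typical numerically stable formulation with the same signature would look like this sketch (an assumption about the implementation, not a copy of it):

#include <algorithm>
#include <cmath>

#include "absl/types/span.h"

// Subtracting the maximum before exponentiating keeps std::exp from
// overflowing for large logits; the result is mathematically unchanged.
void StableSoftmaxSketch(absl::Span<const float> values,
                         absl::Span<float> activated_values) {
  const float max_value = *std::max_element(values.begin(), values.end());
  float sum = 0.f;
  for (size_t i = 0; i < values.size(); ++i) {
    activated_values[i] = std::exp(values[i] - max_value);
    sum += activated_values[i];
  }
  for (size_t i = 0; i < values.size(); ++i) {
    activated_values[i] /= sum;
  }
}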
} // namespace
// Converts Tensors from a vector of Tensor to Segmentation.
@ -222,81 +350,16 @@ absl::Status TensorsToSegmentationCalculator::Process(
std::vector<Image> TensorsToSegmentationCalculator::GetSegmentationResultCpu(
    const Shape& input_shape, const Shape& output_shape,
    const float* tensors_buffer) {
-  std::function<void(absl::Span<const float> values,
-                     absl::Span<float> activated_values)>
-      activation_fn;
-  switch (options_.segmenter_options().activation()) {
-    case SegmenterOptions::SIGMOID:
-      activation_fn = &Sigmoid;
-      break;
-    case SegmenterOptions::SOFTMAX:
-      activation_fn = &StableSoftmax;
-      break;
-    case SegmenterOptions::NONE:
-      // Just copying for NONE activation.
-      activation_fn = [](absl::Span<const float> values,
-                         absl::Span<float> activated_values) {
-        std::copy(values.begin(), values.end(), activated_values.begin());
-      };
-      break;
-  }
-  const bool is_category_mask = options_.segmenter_options().output_type() ==
-                                SegmenterOptions::CATEGORY_MASK;
-  const int cv_mat_type = is_category_mask ? CV_8UC1 : CV_32FC1;
-  const int output_masks_num = output_shape.channels;
-  // TODO Use libyuv for resizing instead.
-  std::vector<cv::Mat> segmented_mask_mats;
-  segmented_mask_mats.reserve(output_masks_num);
-  for (int i = 0; i < output_masks_num; ++i) {
-    segmented_mask_mats.push_back(
-        cv::Mat(input_shape.height, input_shape.width, cv_mat_type));
-  }
-  // Applies activation function.
-  const int tensor_size = input_shape.height * input_shape.width;
-  if (is_category_mask) {
-    for (int i = 0; i < tensor_size; ++i) {
-      absl::Span<const float> confidence_scores(
-          &tensors_buffer[i * input_shape.channels], input_shape.channels);
-      const int maximum_category_idx =
-          std::max_element(confidence_scores.begin(),
-                           confidence_scores.end()) -
-          confidence_scores.begin();
-      segmented_mask_mats[0].at<uint8_t>(
-          i / input_shape.width, i % input_shape.width) = maximum_category_idx;
-    }
+  if (options_.segmenter_options().output_type() ==
+      SegmenterOptions::CATEGORY_MASK) {
+    return ProcessForCategoryMaskCpu(input_shape, output_shape,
+                                     options_.segmenter_options(),
+                                     tensors_buffer);
  } else {
-    std::vector<float> activated_values(input_shape.channels);
-    absl::Span<float> activated_values_span(activated_values);
-    for (int i = 0; i < tensor_size; ++i) {
-      activation_fn(
-          absl::MakeConstSpan(&tensors_buffer[i * input_shape.channels],
-                              input_shape.channels),
-          activated_values_span);
-      for (int j = 0; j < input_shape.channels; ++j) {
-        segmented_mask_mats[j].at<float>(
-            i / input_shape.width, i % input_shape.width) = activated_values[j];
-      }
-    }
+    return ProcessForConfidenceMaskCpu(input_shape, output_shape,
+                                       options_.segmenter_options(),
+                                       tensors_buffer);
  }
-  std::vector<Image> segmented_masks;
-  segmented_masks.reserve(output_masks_num);
-  // Resizes segmented masks to required output size.
-  for (int i = 0; i < segmented_mask_mats.size(); i++) {
-    // Pre-allocates ImageFrame memory to avoid copying from cv::Mat afterward.
-    ImageFrameSharedPtr image_frame_ptr = std::make_shared<ImageFrame>(
-        is_category_mask ? ImageFormat::GRAY8 : ImageFormat::VEC32F1,
-        output_shape.width, output_shape.height, 1);
-    cv::Mat resized_mask_mat_view =
-        mediapipe::formats::MatView(image_frame_ptr.get());
-    cv::resize(segmented_mask_mats[i], resized_mask_mat_view,
-               resized_mask_mat_view.size(), 0, 0,
-               cv_mat_type == CV_8UC1 ? cv::INTER_NEAREST : cv::INTER_LINEAR);
-    segmented_masks.push_back(Image(image_frame_ptr));
-  }
-  return segmented_masks;
}
MEDIAPIPE_REGISTER_NODE(::mediapipe::tasks::TensorsToSegmentationCalculator);

View File

@ -101,20 +101,6 @@ ConvertImageSegmenterOptionsToProto(ImageSegmenterOptions* options) {
          SegmenterOptions::CONFIDENCE_MASK);
      break;
  }
switch (options->activation) {
case ImageSegmenterOptions::Activation::NONE:
options_proto->mutable_segmenter_options()->set_activation(
SegmenterOptions::NONE);
break;
case ImageSegmenterOptions::Activation::SIGMOID:
options_proto->mutable_segmenter_options()->set_activation(
SegmenterOptions::SIGMOID);
break;
case ImageSegmenterOptions::Activation::SOFTMAX:
options_proto->mutable_segmenter_options()->set_activation(
SegmenterOptions::SOFTMAX);
break;
}
  return options_proto;
}

View File

@ -64,15 +64,6 @@ struct ImageSegmenterOptions {
  OutputType output_type = OutputType::CATEGORY_MASK;
// The activation function used on the raw segmentation model output.
enum Activation {
NONE = 0, // No activation function is used.
SIGMOID = 1, // Assumes 1-channel input tensor.
SOFTMAX = 2, // Assumes multi-channel input tensor.
};
Activation activation = Activation::NONE;
  // The user-defined result callback for processing live stream data.
  // The result callback should only be specified when the running mode is set
  // to RunningMode::LIVE_STREAM.

View File

@ -40,6 +40,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
#include "mediapipe/tasks/metadata/image_segmenter_metadata_schema_generated.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h"
#include "mediapipe/util/label_map.pb.h" #include "mediapipe/util/label_map.pb.h"
#include "mediapipe/util/label_map_util.h" #include "mediapipe/util/label_map_util.h"
@ -74,6 +75,7 @@ constexpr char kImageGpuTag[] = "IMAGE_GPU";
constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTensorsTag[] = "TENSORS";
constexpr char kOutputSizeTag[] = "OUTPUT_SIZE"; constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
constexpr char kSegmentationMetadataName[] = "SEGMENTER_METADATA";
// Struct holding the different output streams produced by the image segmenter // Struct holding the different output streams produced by the image segmenter
// subgraph. // subgraph.
@ -130,7 +132,49 @@ absl::Status ConfigureTensorsToSegmentationCalculator(
    const ImageSegmenterGraphOptions& segmenter_option,
    const core::ModelResources& model_resources,
    TensorsToSegmentationCalculatorOptions* options) {
-  *options->mutable_segmenter_options() = segmenter_option.segmenter_options();
+  // Set the default activation function to NONE.
options->mutable_segmenter_options()->set_output_type(
segmenter_option.segmenter_options().output_type());
options->mutable_segmenter_options()->set_activation(SegmenterOptions::NONE);
// Find the custom metadata of ImageSegmenterOptions type in model metadata.
const auto* metadata_extractor = model_resources.GetMetadataExtractor();
bool found_activation_in_metadata = false;
if (metadata_extractor->GetCustomMetadataList() != nullptr &&
metadata_extractor->GetCustomMetadataList()->size() > 0) {
for (const auto& custom_metadata :
*metadata_extractor->GetCustomMetadataList()) {
if (custom_metadata->name()->str() == kSegmentationMetadataName) {
found_activation_in_metadata = true;
auto activation_fb =
GetImageSegmenterOptions(custom_metadata->data()->data())
->activation();
switch (activation_fb) {
case Activation_NONE:
options->mutable_segmenter_options()->set_activation(
SegmenterOptions::NONE);
break;
case Activation_SIGMOID:
options->mutable_segmenter_options()->set_activation(
SegmenterOptions::SIGMOID);
break;
case Activation_SOFTMAX:
options->mutable_segmenter_options()->set_activation(
SegmenterOptions::SOFTMAX);
break;
default:
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"Invalid activation type found in CustomMetadata of "
"ImageSegmenterOptions type.");
}
}
}
}
if (!found_activation_in_metadata) {
LOG(WARNING)
<< "No activation type is found in model metadata. Use NONE for "
"ImageSegmenterGraph.";
}
  const tflite::Model& model = *model_resources.GetTfLiteModel();
  if (model.subgraphs()->size() != 1) {
    return CreateStatusWithPayload(
@ -146,8 +190,6 @@ absl::Status ConfigureTensorsToSegmentationCalculator(
        MediaPipeTasksStatus::kInvalidArgumentError);
  }
const ModelMetadataExtractor* metadata_extractor =
model_resources.GetMetadataExtractor();
  ASSIGN_OR_RETURN(
      *options->mutable_label_items(),
      GetLabelItemsIfAny(*metadata_extractor,
@ -401,7 +443,7 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
    } else {
      ASSIGN_OR_RETURN(const tflite::Tensor* output_tensor,
                       GetOutputTensor(model_resources));
-      const int segmentation_streams_num = *output_tensor->shape()->rbegin();
+      int segmentation_streams_num = *output_tensor->shape()->rbegin();
      for (int i = 0; i < segmentation_streams_num; ++i) {
        segmented_masks.push_back(Source<Image>(
            tensor_to_images[Output<Image>::Multiple(kSegmentationTag)][i]));

View File

@ -62,6 +62,11 @@ constexpr char kSelfie128x128WithMetadata[] = "selfie_segm_128_128_3.tflite";
constexpr char kSelfie144x256WithMetadata[] = "selfie_segm_144_256_3.tflite";
constexpr char kSelfieSegmentation[] = "selfie_segmentation.tflite";
constexpr char kSelfieSegmentationLandscape[] =
"selfie_segmentation_landscape.tflite";
constexpr char kHairSegmentationWithMetadata[] = "hair_segmentation.tflite";
constexpr float kGoldenMaskSimilarity = 0.98;
@ -90,13 +95,8 @@ cv::Mat PostProcessResultMask(const cv::Mat& mask) {
}
Image GetSRGBImage(const std::string& image_path) {
// TODO: fix test so RGB really is used and not BGR/BGRA.
// mediapipe/app/aimatter/segmentation/segmenter_test_common.cc
// golden masks are generated with BGR image. To align with the unittest of
// aimatter segmenter, here reads image as BGR as well (opencv reads image as
// BGR). Once the correctness of mediapipe tasks segmenter is verified, change
// the golden masks to be generated by RGB image.
  cv::Mat image_mat = cv::imread(image_path);
  cv::cvtColor(image_mat, image_mat, cv::COLOR_BGR2RGB);
  mediapipe::ImageFrame image_frame(
      mediapipe::ImageFormat::SRGB, image_mat.cols, image_mat.rows,
      image_mat.step, image_mat.data, [image_mat](uint8_t[]) {});
@ -304,7 +304,6 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
  options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                          ImageSegmenter::Create(std::move(options)));
@ -333,7 +332,6 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) {
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
  options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                          ImageSegmenter::Create(std::move(options)));
@ -364,7 +362,6 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
  options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                          ImageSegmenter::Create(std::move(options)));
@ -388,7 +385,6 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) {
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kSelfie128x128WithMetadata);
  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
  options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                          ImageSegmenter::Create(std::move(options)));
@ -416,7 +412,6 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) {
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kSelfie144x256WithMetadata);
  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
  options->activation = ImageSegmenterOptions::Activation::NONE;
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                          ImageSegmenter::Create(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
@ -435,6 +430,82 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) {
              SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
}
TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) {
Image image =
GetSRGBImage(JoinPath("./", kTestDataDirectory, "portrait.jpg"));
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kSelfieSegmentation);
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
EXPECT_EQ(confidence_masks.size(), 1);
MP_ASSERT_OK(segmenter->Close());
cv::Mat expected_mask = cv::imread(
JoinPath("./", kTestDataDirectory,
"portrait_selfie_segmentation_expected_confidence_mask.jpg"),
cv::IMREAD_GRAYSCALE);
cv::Mat expected_mask_float;
expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);
cv::Mat selfie_mask = mediapipe::formats::MatView(
confidence_masks[0].GetImageFrameSharedPtr().get());
EXPECT_THAT(selfie_mask,
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
}
TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationCategoryMask) {
Image image =
GetSRGBImage(JoinPath("./", kTestDataDirectory, "portrait.jpg"));
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kSelfieSegmentation);
options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto category_mask, segmenter->Segment(image));
EXPECT_EQ(category_mask.size(), 1);
MP_ASSERT_OK(segmenter->Close());
cv::Mat selfie_mask = mediapipe::formats::MatView(
category_mask[0].GetImageFrameSharedPtr().get());
cv::Mat expected_mask = cv::imread(
JoinPath("./", kTestDataDirectory,
"portrait_selfie_segmentation_expected_category_mask.jpg"),
cv::IMREAD_GRAYSCALE);
EXPECT_THAT(selfie_mask,
SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, 255));
}
TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationLandscapeCategoryMask) {
Image image =
GetSRGBImage(JoinPath("./", kTestDataDirectory, "portrait.jpg"));
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kSelfieSegmentationLandscape);
options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto category_mask, segmenter->Segment(image));
EXPECT_EQ(category_mask.size(), 1);
MP_ASSERT_OK(segmenter->Close());
cv::Mat selfie_mask = mediapipe::formats::MatView(
category_mask[0].GetImageFrameSharedPtr().get());
cv::Mat expected_mask = cv::imread(
JoinPath(
"./", kTestDataDirectory,
"portrait_selfie_segmentation_landscape_expected_category_mask.jpg"),
cv::IMREAD_GRAYSCALE);
EXPECT_THAT(selfie_mask,
SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, 255));
}
TEST_F(ImageModeTest, SucceedsHairSegmentation) {
  Image image =
      GetSRGBAImage(JoinPath("./", kTestDataDirectory, "portrait.jpg"));
@ -442,7 +513,6 @@ TEST_F(ImageModeTest, SucceedsHairSegmentation) {
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kHairSegmentationWithMetadata);
  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
  options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                          ImageSegmenter::Create(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));

View File

@ -0,0 +1,76 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# Docs for Mediapipe Tasks Interactive Segmenter
# TODO: add doc link.
cc_library(
name = "interactive_segmenter",
srcs = ["interactive_segmenter.cc"],
hdrs = ["interactive_segmenter.h"],
deps = [
":interactive_segmenter_graph",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc/components/containers:keypoint",
"//mediapipe/tasks/cc/core:base_options",
"//mediapipe/tasks/cc/vision/core:base_vision_task_api",
"//mediapipe/tasks/cc/vision/core:image_processing_options",
"//mediapipe/tasks/cc/vision/core:running_mode",
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto",
"//mediapipe/util:color_cc_proto",
"//mediapipe/util:render_data_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
)
cc_library(
name = "interactive_segmenter_graph",
srcs = ["interactive_segmenter_graph.cc"],
deps = [
"@com_google_absl//absl/strings",
"//mediapipe/calculators/image:set_alpha_calculator",
"//mediapipe/calculators/util:annotation_overlay_calculator",
"//mediapipe/calculators/util:flat_color_image_calculator",
"//mediapipe/calculators/util:flat_color_image_calculator_cc_proto",
"//mediapipe/calculators/util:from_image_calculator",
"//mediapipe/calculators/util:to_image_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
"//mediapipe/util:color_cc_proto",
"//mediapipe/util:label_map_cc_proto",
"//mediapipe/util:render_data_cc_proto",
] + select({
"//mediapipe/gpu:disable_gpu": [],
"//conditions:default": [
"//mediapipe/gpu:gpu_buffer_to_image_frame_calculator",
"//mediapipe/gpu:image_frame_to_gpu_buffer_calculator",
],
}),
alwayslink = 1,
)
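For local verification, the new targets above can be built directly with Bazel (an illustrative invocation assuming a standard MediaPipe checkout; the commands are not part of this change):

bazel build //mediapipe/tasks/cc/vision/interactive_segmenter:interactive_segmenter
bazel build //mediapipe/tasks/cc/vision/interactive_segmenter:interactive_segmenter_graph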

View File

@ -0,0 +1,163 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h"
#include <utility>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/render_data.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace interactive_segmenter {
namespace {
constexpr char kSegmentationStreamName[] = "segmented_mask_out";
constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out";
constexpr char kRoiStreamName[] = "roi_in";
constexpr char kNormRectStreamName[] = "norm_rect_in";
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
constexpr char kImageTag[] = "IMAGE";
constexpr char kRoiTag[] = "ROI";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kSubgraphTypeName[] =
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph";
using ::mediapipe::CalculatorGraphConfig;
using ::mediapipe::Image;
using ::mediapipe::NormalizedRect;
using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions;
using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision::
image_segmenter::proto::ImageSegmenterGraphOptions;
// Creates a MediaPipe graph config that only contains a single subgraph node of
// "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph".
CalculatorGraphConfig CreateGraphConfig(
std::unique_ptr<ImageSegmenterGraphOptionsProto> options) {
api2::builder::Graph graph;
auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
task_subgraph.GetOptions<ImageSegmenterGraphOptionsProto>().Swap(
options.get());
graph.In(kImageTag).SetName(kImageInStreamName);
graph.In(kRoiTag).SetName(kRoiStreamName);
graph.In(kNormRectTag).SetName(kNormRectStreamName);
task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >>
graph.Out(kGroupedSegmentationTag);
task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
graph.Out(kImageTag);
graph.In(kImageTag) >> task_subgraph.In(kImageTag);
graph.In(kRoiTag) >> task_subgraph.In(kRoiTag);
graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag);
return graph.GetConfig();
}
// Converts the user-facing InteractiveSegmenterOptions struct to the internal
// ImageSegmenterGraphOptions proto.
std::unique_ptr<ImageSegmenterGraphOptionsProto>
ConvertImageSegmenterOptionsToProto(InteractiveSegmenterOptions* options) {
auto options_proto = std::make_unique<ImageSegmenterGraphOptionsProto>();
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
options_proto->mutable_base_options()->Swap(base_options_proto.get());
switch (options->output_type) {
case InteractiveSegmenterOptions::OutputType::CATEGORY_MASK:
options_proto->mutable_segmenter_options()->set_output_type(
SegmenterOptions::CATEGORY_MASK);
break;
case InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK:
options_proto->mutable_segmenter_options()->set_output_type(
SegmenterOptions::CONFIDENCE_MASK);
break;
}
return options_proto;
}
// Converts the user-facing RegionOfInterest struct to the RenderData proto that
// is used in the subgraph.
absl::StatusOr<RenderData> ConvertRoiToRenderData(const RegionOfInterest& roi) {
RenderData result;
switch (roi.format) {
case RegionOfInterest::UNSPECIFIED:
return absl::InvalidArgumentError(
"RegionOfInterest format not specified");
case RegionOfInterest::KEYPOINT:
RET_CHECK(roi.keypoint.has_value());
auto* annotation = result.add_render_annotations();
annotation->mutable_color()->set_r(255);
auto* point = annotation->mutable_point();
point->set_normalized(true);
point->set_x(roi.keypoint->x);
point->set_y(roi.keypoint->y);
return result;
}
return absl::UnimplementedError("Unrecognized format");
}
} // namespace
absl::StatusOr<std::unique_ptr<InteractiveSegmenter>>
InteractiveSegmenter::Create(
std::unique_ptr<InteractiveSegmenterOptions> options) {
auto options_proto = ConvertImageSegmenterOptionsToProto(options.get());
return core::VisionTaskApiFactory::Create<InteractiveSegmenter,
ImageSegmenterGraphOptionsProto>(
CreateGraphConfig(std::move(options_proto)),
std::move(options->base_options.op_resolver), core::RunningMode::IMAGE,
/*packets_callback=*/nullptr);
}
absl::StatusOr<std::vector<Image>> InteractiveSegmenter::Segment(
mediapipe::Image image, const RegionOfInterest& roi,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrCat("GPU input images are currently not supported."),
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
}
ASSIGN_OR_RETURN(
NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
ASSIGN_OR_RETURN(RenderData roi_as_render_data, ConvertRoiToRenderData(roi));
ASSIGN_OR_RETURN(
auto output_packets,
ProcessImageData(
{{kImageInStreamName, mediapipe::MakePacket<Image>(std::move(image))},
{kRoiStreamName,
mediapipe::MakePacket<RenderData>(std::move(roi_as_render_data))},
{kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect))}}));
return output_packets[kSegmentationStreamName].Get<std::vector<Image>>();
}
} // namespace interactive_segmenter
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,136 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_VISION_INTERACTIVE_SEGMENTER_INTERACTIVE_SEGMENTER_H_
#define MEDIAPIPE_TASKS_CC_VISION_INTERACTIVE_SEGMENTER_INTERACTIVE_SEGMENTER_H_
#include <memory>
#include <optional>
#include <utility>
#include <vector>
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/components/containers/keypoint.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace interactive_segmenter {
// The options for configuring a mediapipe interactive segmenter task.
struct InteractiveSegmenterOptions {
// Base options for configuring MediaPipe Tasks, such as specifying the model
// file with metadata, accelerator options, op resolver, etc.
tasks::core::BaseOptions base_options;
// The output type of segmentation results.
enum OutputType {
// Gives a single output mask where each pixel represents the class which
// the pixel in the original image was predicted to belong to.
CATEGORY_MASK = 0,
// Gives a list of output masks where, for each mask, each pixel represents
// the prediction confidence, usually in the [0, 1] range.
CONFIDENCE_MASK = 1,
};
OutputType output_type = OutputType::CATEGORY_MASK;
};
// The Region-Of-Interest (ROI) to interact with.
struct RegionOfInterest {
enum Format {
UNSPECIFIED = 0, // Format not specified.
KEYPOINT = 1, // Using keypoint to represent ROI.
};
// Specifies the format used to specify the region-of-interest. Note that
// using `UNSPECIFIED` is invalid and will lead to an `InvalidArgument` status
// being returned.
Format format = Format::UNSPECIFIED;
  // Represents the ROI in keypoint format; this should be non-nullopt if
  // `format` is `KEYPOINT`.
std::optional<components::containers::NormalizedKeypoint> keypoint;
};
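For illustration, a keypoint ROI as described above can be constructed like this (a hypothetical helper, not part of this header):

// Builds a keypoint ROI at normalized coordinates (x, y), matching how the
// tests for this task construct their regions of interest.
RegionOfInterest MakeKeypointRoi(float x, float y) {
  RegionOfInterest roi;
  roi.format = RegionOfInterest::KEYPOINT;
  roi.keypoint = components::containers::NormalizedKeypoint{x, y};
  return roi;
}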
// Performs interactive segmentation on images.
//
// Users can represent user interaction through `RegionOfInterest`, which gives
// a hint to InteractiveSegmenter to perform segmentation focusing on the given
// region of interest.
//
// The API expects a TFLite model with mandatory TFLite Model Metadata.
//
// Input tensor:
// (kTfLiteUInt8/kTfLiteFloat32)
// - image input of size `[batch x height x width x channels]`.
// - batch inference is not supported (`batch` is required to be 1).
// - RGB inputs are supported (`channels` is required to be 3).
// - if type is kTfLiteFloat32, NormalizationOptions are required to be
// attached to the metadata for input normalization.
// Output tensors:
// (kTfLiteUInt8/kTfLiteFloat32)
// - list of segmented masks.
// - if `output_type` is CATEGORY_MASK, a uint8 Image vector of size 1.
// - if `output_type` is CONFIDENCE_MASK, a float32 Image list of size
// `channels`.
// - batch is always 1
class InteractiveSegmenter : tasks::vision::core::BaseVisionTaskApi {
public:
using BaseVisionTaskApi::BaseVisionTaskApi;
// Creates an InteractiveSegmenter from the provided options. A non-default
// OpResolver can be specified in the BaseOptions of
// InteractiveSegmenterOptions, to support custom Ops of the segmentation
// model.
static absl::StatusOr<std::unique_ptr<InteractiveSegmenter>> Create(
std::unique_ptr<InteractiveSegmenterOptions> options);
// Performs image segmentation on the provided single image.
//
// The image can be of any size with format RGB.
//
// The `roi` parameter is used to represent user's region of interest for
// segmentation.
//
// The optional 'image_processing_options' parameter can be used to specify
// the rotation to apply to the image before performing segmentation, by
// setting its 'rotation_degrees' field. Note that specifying a
// region-of-interest using the 'region_of_interest' field is NOT supported
// and will result in an invalid argument error being returned.
//
  // If the output_type is CATEGORY_MASK, the returned vector contains a
  // single category mask image.
  // If the output_type is CONFIDENCE_MASK, the returned vector contains one
  // confidence mask image per category.
absl::StatusOr<std::vector<mediapipe::Image>> Segment(
mediapipe::Image image, const RegionOfInterest& roi,
std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);
  // Shuts down the InteractiveSegmenter when all work is done.
absl::Status Close() { return runner_->Close(); }
};
} // namespace interactive_segmenter
} // namespace vision
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_VISION_INTERACTIVE_SEGMENTER_INTERACTIVE_SEGMENTER_H_
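Putting the API above together, a typical IMAGE-mode call sequence might look like the following sketch (the model path is hypothetical, and MediaPipe's ASSIGN_OR_RETURN status macro is assumed to be available):

#include <memory>
#include <utility>
#include <vector>

#include "mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h"

absl::Status RunInteractiveSegmentation(mediapipe::Image image) {
  namespace iseg = ::mediapipe::tasks::vision::interactive_segmenter;
  auto options = std::make_unique<iseg::InteractiveSegmenterOptions>();
  options->base_options.model_asset_path = "/path/to/model.tflite";  // hypothetical
  options->output_type =
      iseg::InteractiveSegmenterOptions::OutputType::CATEGORY_MASK;
  ASSIGN_OR_RETURN(std::unique_ptr<iseg::InteractiveSegmenter> segmenter,
                   iseg::InteractiveSegmenter::Create(std::move(options)));

  // Point the segmenter at the center of the image.
  iseg::RegionOfInterest roi;
  roi.format = iseg::RegionOfInterest::KEYPOINT;
  roi.keypoint =
      mediapipe::tasks::components::containers::NormalizedKeypoint{0.5f, 0.5f};

  ASSIGN_OR_RETURN(std::vector<mediapipe::Image> masks,
                   segmenter->Segment(image, roi));
  // With CATEGORY_MASK, `masks` holds a single uint8 mask image.
  return segmenter->Close();
}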

View File

@ -0,0 +1,198 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/strings/string_view.h"
#include "mediapipe/calculators/util/flat_color_image_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/label_map.pb.h"
#include "mediapipe/util/render_data.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace interactive_segmenter {
namespace {
using image_segmenter::proto::ImageSegmenterGraphOptions;
using ::mediapipe::Image;
using ::mediapipe::NormalizedRect;
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
constexpr char kSegmentationTag[] = "SEGMENTATION";
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
constexpr char kImageTag[] = "IMAGE";
constexpr char kImageCpuTag[] = "IMAGE_CPU";
constexpr char kImageGpuTag[] = "IMAGE_GPU";
constexpr char kAlphaTag[] = "ALPHA";
constexpr char kAlphaGpuTag[] = "ALPHA_GPU";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kRoiTag[] = "ROI";
constexpr char kVideoTag[] = "VIDEO";
// Updates the graph to return an alpha stream that has the same dimensions as
// `image` and is rendered from `roi`. If `use_gpu` is true, the returned
// `Source` is in GpuBuffer format; otherwise it uses ImageFrame.
Source<> RoiToAlpha(Source<Image> image, Source<RenderData> roi, bool use_gpu,
Graph& graph) {
// TODO: Replace with efficient implementation.
const absl::string_view image_tag_with_suffix =
use_gpu ? kImageGpuTag : kImageCpuTag;
// Generates a blank canvas with same size as input image.
auto& flat_color = graph.AddNode("FlatColorImageCalculator");
auto& flat_color_options =
flat_color.GetOptions<FlatColorImageCalculatorOptions>();
// SetAlphaCalculator only takes 1st channel.
flat_color_options.mutable_color()->set_r(0);
image >> flat_color.In(kImageTag)[0];
auto blank_canvas = flat_color.Out(kImageTag)[0];
auto& from_mp_image = graph.AddNode("FromImageCalculator");
blank_canvas >> from_mp_image.In(kImageTag);
auto blank_canvas_in_cpu_or_gpu = from_mp_image.Out(image_tag_with_suffix);
auto& roi_to_alpha = graph.AddNode("AnnotationOverlayCalculator");
blank_canvas_in_cpu_or_gpu >>
roi_to_alpha.In(use_gpu ? kImageGpuTag : kImageTag);
roi >> roi_to_alpha.In(0);
auto alpha = roi_to_alpha.Out(use_gpu ? kImageGpuTag : kImageTag);
return alpha;
}
} // namespace
// An "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"
// performs semantic segmentation given user's region-of-interest. Two kinds of
// outputs are provided: SEGMENTATION and GROUPED_SEGMENTATION. Users can
// retrieve segmented mask of only particular category/channel from
// SEGMENTATION, and users can also get all segmented masks from
// GROUPED_SEGMENTATION.
// - Accepts CPU input images and outputs segmented masks on CPU.
//
// Inputs:
// IMAGE - Image
// Image to perform segmentation on.
// ROI - RenderData proto
// Region of interest based on user interaction. Currently only the Point
// format is supported, and the Color has to be (255, 255, 255).
// NORM_RECT - NormalizedRect @Optional
// Describes image rotation and region of image to perform detection
// on.
// @Optional: rect covering the whole image is used if not specified.
//
// Outputs:
// SEGMENTATION - mediapipe::Image @Multiple
// Segmented masks for individual categories. The segmented mask of a single
// category can be accessed by its index-based output stream.
// GROUPED_SEGMENTATION - std::vector<mediapipe::Image>
// The output segmented masks grouped in a vector.
// IMAGE - mediapipe::Image
// The image that image segmenter runs on.
//
// Example:
// node {
// calculator:
// "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"
// input_stream: "IMAGE:image"
// input_stream: "ROI:region_of_interest"
// output_stream: "SEGMENTATION:segmented_masks"
// options {
// [mediapipe.tasks.vision.image_segmenter.proto.ImageSegmenterGraphOptions.ext]
// {
// base_options {
// model_asset {
// file_name: "/path/to/model.tflite"
// }
// }
// segmenter_options {
// output_type: CONFIDENCE_MASK
// }
// }
// }
// }
class InteractiveSegmenterGraph : public core::ModelTaskGraph {
public:
absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
mediapipe::SubgraphContext* sc) override {
Graph graph;
const auto& task_options = sc->Options<ImageSegmenterGraphOptions>();
bool use_gpu =
components::processors::DetermineImagePreprocessingGpuBackend(
task_options.base_options().acceleration());
Source<Image> image = graph[Input<Image>(kImageTag)];
Source<RenderData> roi = graph[Input<RenderData>(kRoiTag)];
Source<NormalizedRect> norm_rect =
graph[Input<NormalizedRect>(kNormRectTag)];
const absl::string_view image_tag_with_suffix =
use_gpu ? kImageGpuTag : kImageCpuTag;
const absl::string_view alpha_tag_with_suffix =
use_gpu ? kAlphaGpuTag : kAlphaTag;
auto& from_mp_image = graph.AddNode("FromImageCalculator");
image >> from_mp_image.In(kImageTag);
auto image_in_cpu_or_gpu = from_mp_image.Out(image_tag_with_suffix);
auto alpha_in_cpu_or_gpu = RoiToAlpha(image, roi, use_gpu, graph);
auto& set_alpha = graph.AddNode("SetAlphaCalculator");
image_in_cpu_or_gpu >> set_alpha.In(use_gpu ? kImageGpuTag : kImageTag);
alpha_in_cpu_or_gpu >> set_alpha.In(alpha_tag_with_suffix);
auto image_in_cpu_or_gpu_with_set_alpha =
set_alpha.Out(use_gpu ? kImageGpuTag : kImageTag);
auto& to_mp_image = graph.AddNode("ToImageCalculator");
image_in_cpu_or_gpu_with_set_alpha >> to_mp_image.In(image_tag_with_suffix);
auto image_with_set_alpha = to_mp_image.Out(kImageTag);
auto& image_segmenter = graph.AddNode(
"mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph");
image_segmenter.GetOptions<ImageSegmenterGraphOptions>() = task_options;
image_with_set_alpha >> image_segmenter.In(kImageTag);
norm_rect >> image_segmenter.In(kNormRectTag);
image_segmenter.Out(kSegmentationTag) >>
graph[Output<Image>(kSegmentationTag)];
image_segmenter.Out(kGroupedSegmentationTag) >>
graph[Output<std::vector<Image>>(kGroupedSegmentationTag)];
image_segmenter.Out(kImageTag) >> graph[Output<Image>(kImageTag)];
return graph.GetConfig();
}
};
// REGISTER_MEDIAPIPE_GRAPH argument has to fit on one line to work properly.
// clang-format off
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::vision::interactive_segmenter::InteractiveSegmenterGraph);
// clang-format on
} // namespace interactive_segmenter
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,306 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h"
#include <memory>
#include <string>
#include "absl/flags/flag.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/components/containers/keypoint.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
#include "tensorflow/lite/mutable_op_resolver.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace interactive_segmenter {
namespace {
using ::mediapipe::Image;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::NormalizedKeypoint;
using ::mediapipe::tasks::components::containers::RectF;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr;
using ::testing::Optional;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kPtmModel[] = "ptm_512_hdt_ptm_woid.tflite";
constexpr char kCatsAndDogsJpg[] = "cats_and_dogs.jpg";
// Golden mask for the dogs in cats_and_dogs.jpg.
constexpr char kCatsAndDogsMaskDog1[] = "cats_and_dogs_mask_dog1.png";
constexpr char kCatsAndDogsMaskDog2[] = "cats_and_dogs_mask_dog2.png";
constexpr float kGoldenMaskSimilarity = 0.97;
// Magnification factor used when creating the golden category masks to make
// them more human-friendly. Since interactive segmenter has only 2 categories,
// the golden mask uses 0 or 255 for each pixel.
constexpr int kGoldenMaskMagnificationFactor = 255;
// Intentionally converts the output into CV_8UC1 and then again into CV_32FC1,
// as the expected outputs are stored in CV_8UC1; this round trip allows a fair
// comparison.
cv::Mat PostProcessResultMask(const cv::Mat& mask) {
cv::Mat mask_float;
mask.convertTo(mask_float, CV_8UC1, 255);
mask_float.convertTo(mask_float, CV_32FC1, 1 / 255.f);
return mask_float;
}
double CalculateSum(const cv::Mat& m) {
double sum = 0.0;
cv::Scalar s = cv::sum(m);
for (int i = 0; i < m.channels(); ++i) {
sum += s.val[i];
}
return sum;
}
double CalculateSoftIOU(const cv::Mat& m1, const cv::Mat& m2) {
cv::Mat intersection;
cv::multiply(m1, m2, intersection);
double intersection_value = CalculateSum(intersection);
double union_value =
CalculateSum(m1.mul(m1)) + CalculateSum(m2.mul(m2)) - intersection_value;
return union_value > 0.0 ? intersection_value / union_value : 0.0;
}
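A tiny worked example of this soft IoU (illustrative only): for m1 = m2 = [1.0, 0.5], the intersection is 1.0*1.0 + 0.5*0.5 = 1.25 and the union is (1.0 + 0.25) + (1.0 + 0.25) - 1.25 = 1.25, so identical masks score exactly 1.0.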
MATCHER_P2(SimilarToFloatMask, expected_mask, similarity_threshold, "") {
cv::Mat actual_mask = PostProcessResultMask(arg);
  return actual_mask.rows == expected_mask.rows &&
         actual_mask.cols == expected_mask.cols &&
         CalculateSoftIOU(actual_mask, expected_mask) > similarity_threshold;
}
MATCHER_P3(SimilarToUint8Mask, expected_mask, similarity_threshold,
magnification_factor, "") {
if (arg.rows != expected_mask.rows || arg.cols != expected_mask.cols) {
return false;
}
int consistent_pixels = 0;
const int num_pixels = expected_mask.rows * expected_mask.cols;
for (int i = 0; i < num_pixels; ++i) {
consistent_pixels +=
(arg.data[i] * magnification_factor == expected_mask.data[i]);
}
return static_cast<float>(consistent_pixels) / num_pixels >=
similarity_threshold;
}
class CreateFromOptionsTest : public tflite_shims::testing::Test {};
class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver {
public:
DeepLabOpResolverMissingOps() {
AddBuiltin(::tflite::BuiltinOperator_ADD,
::tflite::ops::builtin::Register_ADD());
}
DeepLabOpResolverMissingOps(const DeepLabOpResolverMissingOps& r) = delete;
};
TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
auto options = std::make_unique<InteractiveSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kPtmModel);
options->base_options.op_resolver =
absl::make_unique<DeepLabOpResolverMissingOps>();
auto segmenter_or = InteractiveSegmenter::Create(std::move(options));
// TODO: Make MediaPipe InferenceCalculator report the detailed
// interpreter errors (e.g., "Encountered unresolved custom op").
EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInternal);
EXPECT_THAT(
segmenter_or.status().message(),
testing::HasSubstr("interpreter_builder(&interpreter) == kTfLiteOk"));
}
TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
absl::StatusOr<std::unique_ptr<InteractiveSegmenter>> segmenter_or =
InteractiveSegmenter::Create(
std::make_unique<InteractiveSegmenterOptions>());
EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(
segmenter_or.status().message(),
HasSubstr("ExternalFile must specify at least one of 'file_content', "
"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
EXPECT_THAT(segmenter_or.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kRunnerInitializationError))));
}
struct InteractiveSegmenterTestParams {
std::string test_name;
RegionOfInterest::Format format;
NormalizedKeypoint roi;
std::string golden_mask_file;
float similarity_threshold;
};
using SucceedSegmentationWithRoi =
::testing::TestWithParam<InteractiveSegmenterTestParams>;
TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
const InteractiveSegmenterTestParams& params = GetParam();
MP_ASSERT_OK_AND_ASSIGN(
Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
RegionOfInterest interaction_roi;
interaction_roi.format = params.format;
interaction_roi.keypoint = params.roi;
auto options = std::make_unique<InteractiveSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kPtmModel);
options->output_type = InteractiveSegmenterOptions::OutputType::CATEGORY_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
InteractiveSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto category_masks,
segmenter->Segment(image, interaction_roi));
EXPECT_EQ(category_masks.size(), 1);
cv::Mat actual_mask = mediapipe::formats::MatView(
category_masks[0].GetImageFrameSharedPtr().get());
cv::Mat expected_mask =
cv::imread(JoinPath("./", kTestDataDirectory, params.golden_mask_file),
cv::IMREAD_GRAYSCALE);
EXPECT_THAT(actual_mask,
SimilarToUint8Mask(expected_mask, params.similarity_threshold,
kGoldenMaskMagnificationFactor));
}
TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
const auto& params = GetParam();
MP_ASSERT_OK_AND_ASSIGN(
Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
RegionOfInterest interaction_roi;
interaction_roi.format = params.format;
interaction_roi.keypoint = params.roi;
auto options = std::make_unique<InteractiveSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kPtmModel);
options->output_type =
InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
InteractiveSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks,
segmenter->Segment(image, interaction_roi));
EXPECT_EQ(confidence_masks.size(), 2);
cv::Mat expected_mask =
cv::imread(JoinPath("./", kTestDataDirectory, params.golden_mask_file),
cv::IMREAD_GRAYSCALE);
cv::Mat expected_mask_float;
expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);
cv::Mat actual_mask = mediapipe::formats::MatView(
confidence_masks[1].GetImageFrameSharedPtr().get());
EXPECT_THAT(actual_mask, SimilarToFloatMask(expected_mask_float,
params.similarity_threshold));
}
INSTANTIATE_TEST_SUITE_P(
SucceedSegmentationWithRoiTest, SucceedSegmentationWithRoi,
::testing::ValuesIn<InteractiveSegmenterTestParams>(
{{"PointToDog1", RegionOfInterest::KEYPOINT,
NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsMaskDog1, 0.84f},
{"PointToDog2", RegionOfInterest::KEYPOINT,
NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsMaskDog2,
kGoldenMaskSimilarity}}),
[](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>&
info) { return info.param.test_name; });
class ImageModeTest : public tflite_shims::testing::Test {};
// TODO: fix this unit test after the image segmenter handles
// post-processing correctly with rotated images.
TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) {
MP_ASSERT_OK_AND_ASSIGN(
Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
RegionOfInterest interaction_roi;
interaction_roi.format = RegionOfInterest::KEYPOINT;
interaction_roi.keypoint = NormalizedKeypoint{0.66, 0.66};
auto options = std::make_unique<InteractiveSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kPtmModel);
options->output_type =
InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
InteractiveSegmenter::Create(std::move(options)));
ImageProcessingOptions image_processing_options;
image_processing_options.rotation_degrees = -90;
MP_ASSERT_OK_AND_ASSIGN(
auto confidence_masks,
segmenter->Segment(image, interaction_roi, image_processing_options));
EXPECT_EQ(confidence_masks.size(), 2);
}
TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
MP_ASSERT_OK_AND_ASSIGN(
Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
RegionOfInterest interaction_roi;
interaction_roi.format = RegionOfInterest::KEYPOINT;
interaction_roi.keypoint = NormalizedKeypoint{0.66, 0.66};
auto options = std::make_unique<InteractiveSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kPtmModel);
options->output_type =
InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
InteractiveSegmenter::Create(std::move(options)));
RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
auto results =
segmenter->Segment(image, interaction_roi, image_processing_options);
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(results.status().message(),
HasSubstr("This task doesn't support region-of-interest"));
EXPECT_THAT(
results.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
}
} // namespace
} // namespace interactive_segmenter
} // namespace vision
} // namespace tasks
} // namespace mediapipe


@@ -641,9 +641,6 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
           SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK);
     }
-    // TODO: remove this once activation is handled in metadata and grpah level.
-    segmenterOptionsBuilder.setActivation(
-        SegmenterOptionsProto.SegmenterOptions.Activation.SOFTMAX);
     taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder);
     return CalculatorOptions.newBuilder()
         .setExtension(


@@ -73,6 +73,15 @@ py_library(
     ],
 )
 
+py_library(
+    name = "keypoint",
+    srcs = ["keypoint.py"],
+    deps = [
+        "//mediapipe/framework/formats:location_data_py_pb2",
+        "//mediapipe/tasks/python/core:optional_dependencies",
+    ],
+)
+
 py_library(
     name = "matrix_data",
     srcs = ["matrix_data.py"],
@@ -88,6 +97,7 @@ py_library(
     deps = [
         ":bounding_box",
         ":category",
+        ":keypoint",
         "//mediapipe/framework/formats:detection_py_pb2",
         "//mediapipe/framework/formats:location_data_py_pb2",
         "//mediapipe/tasks/python/core:optional_dependencies",


@@ -14,12 +14,13 @@
 """Detections data class."""
 
 import dataclasses
-from typing import Any, List
+from typing import Any, List, Optional
 
 from mediapipe.framework.formats import detection_pb2
 from mediapipe.framework.formats import location_data_pb2
 from mediapipe.tasks.python.components.containers import bounding_box as bounding_box_module
 from mediapipe.tasks.python.components.containers import category as category_module
+from mediapipe.tasks.python.components.containers import keypoint as keypoint_module
 from mediapipe.tasks.python.core.optional_dependencies import doc_controls
 
 _DetectionListProto = detection_pb2.DetectionList
@@ -34,10 +35,12 @@ class Detection:
   Attributes:
     bounding_box: A BoundingBox object.
     categories: A list of Category objects.
+    keypoints: A list of NormalizedKeypoint objects.
   """
 
   bounding_box: bounding_box_module.BoundingBox
   categories: List[category_module.Category]
+  keypoints: Optional[List[keypoint_module.NormalizedKeypoint]] = None
 
   @doc_controls.do_not_generate_docs
   def to_pb2(self) -> _DetectionProto:
@@ -46,6 +49,8 @@ class Detection:
     label_ids = []
     scores = []
     display_names = []
+    relative_keypoints = []
+
     for category in self.categories:
       scores.append(category.score)
       if category.index:
@@ -54,6 +59,20 @@ class Detection:
         labels.append(category.category_name)
       if category.display_name:
         display_names.append(category.display_name)
+
+    if self.keypoints:
+      for keypoint in self.keypoints:
+        relative_keypoint_proto = _LocationDataProto.RelativeKeypoint()
+        if keypoint.x is not None:
+          relative_keypoint_proto.x = keypoint.x
+        if keypoint.y is not None:
+          relative_keypoint_proto.y = keypoint.y
+        if keypoint.label:
+          relative_keypoint_proto.keypoint_label = keypoint.label
+        if keypoint.score is not None:
+          relative_keypoint_proto.score = keypoint.score
+        relative_keypoints.append(relative_keypoint_proto)
+
     return _DetectionProto(
         label=labels,
         label_id=label_ids,
@@ -61,28 +80,52 @@ class Detection:
         display_name=display_names,
         location_data=_LocationDataProto(
             format=_LocationDataProto.Format.BOUNDING_BOX,
-            bounding_box=self.bounding_box.to_pb2()))
+            bounding_box=self.bounding_box.to_pb2(),
+            relative_keypoints=relative_keypoints,
+        ),
+    )
 
   @classmethod
   @doc_controls.do_not_generate_docs
   def create_from_pb2(cls, pb2_obj: _DetectionProto) -> 'Detection':
     """Creates a `Detection` object from the given protobuf object."""
     categories = []
+    keypoints = []
     for idx, score in enumerate(pb2_obj.score):
       categories.append(
           category_module.Category(
               score=score,
               index=pb2_obj.label_id[idx]
-              if idx < len(pb2_obj.label_id) else None,
+              if idx < len(pb2_obj.label_id)
+              else None,
               category_name=pb2_obj.label[idx]
-              if idx < len(pb2_obj.label) else None,
+              if idx < len(pb2_obj.label)
+              else None,
               display_name=pb2_obj.display_name[idx]
-              if idx < len(pb2_obj.display_name) else None))
+              if idx < len(pb2_obj.display_name)
+              else None,
+          )
+      )
+
+    if pb2_obj.location_data.relative_keypoints:
+      for elem in pb2_obj.location_data.relative_keypoints:
+        keypoints.append(
+            keypoint_module.NormalizedKeypoint(
+                x=elem.x,
+                y=elem.y,
+                label=elem.keypoint_label,
+                score=elem.score,
+            )
+        )
+
     return Detection(
-        bounding_box=bounding_box_module.BoundingBox.create_from_pb2(
-            pb2_obj.location_data.bounding_box),
-        categories=categories)
+        bounding_box=bounding_box_module.BoundingBox.create_from_pb2(
+            pb2_obj.location_data.bounding_box
+        ),
+        categories=categories,
+        keypoints=keypoints,
+    )
 
   def __eq__(self, other: Any) -> bool:
     """Checks if this object is equal to the given object.


@@ -0,0 +1,77 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keypoint data class."""
import dataclasses
from typing import Any, Optional
from mediapipe.framework.formats import location_data_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_RelativeKeypointProto = location_data_pb2.LocationData.RelativeKeypoint
@dataclasses.dataclass
class NormalizedKeypoint:
"""A normalized keypoint.
Normalized keypoint represents a point in 2D space with x, y coordinates.
x and y are normalized to [0.0, 1.0] by the image width and height
respectively.
Attributes:
    x: The x coordinate of the normalized keypoint.
    y: The y coordinate of the normalized keypoint.
label: The optional label of the keypoint.
score: The score of the keypoint.
"""
x: Optional[float] = None
y: Optional[float] = None
label: Optional[str] = None
score: Optional[float] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _RelativeKeypointProto:
"""Generates a RelativeKeypoint protobuf object."""
return _RelativeKeypointProto(
x=self.x, y=self.y, keypoint_label=self.label, score=self.score
)
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(
cls, pb2_obj: _RelativeKeypointProto
) -> 'NormalizedKeypoint':
"""Creates a `NormalizedKeypoint` object from the given protobuf object."""
return NormalizedKeypoint(
x=pb2_obj.x,
y=pb2_obj.y,
label=pb2_obj.keypoint_label,
score=pb2_obj.score,
)
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, NormalizedKeypoint):
return False
return self.to_pb2().__eq__(other.to_pb2())
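
As a quick illustration of the equality contract above (a sketch, assuming the module is importable via the `keypoint` target added in the BUILD hunk earlier): `__eq__` compares the underlying protos, so a `to_pb2()`/`create_from_pb2()` round trip yields an equal object.

from mediapipe.tasks.python.components.containers import keypoint as keypoint_module

kp = keypoint_module.NormalizedKeypoint(x=0.25, y=0.75, label='nose', score=0.9)

# The round-tripped keypoint compares equal because both sides are
# converted to RelativeKeypoint protos before comparison.
roundtripped = keypoint_module.NormalizedKeypoint.create_from_pb2(kp.to_pb2())
assert roundtripped == kp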


@@ -92,6 +92,29 @@ py_test(
     ],
 )
 
+py_test(
+    name = "face_detector_test",
+    srcs = ["face_detector_test.py"],
+    data = [
+        "//mediapipe/tasks/testdata/vision:test_images",
+        "//mediapipe/tasks/testdata/vision:test_models",
+        "//mediapipe/tasks/testdata/vision:test_protos",
+    ],
+    deps = [
+        "//mediapipe/framework/formats:detection_py_pb2",
+        "//mediapipe/python:_framework_bindings",
+        "//mediapipe/tasks/python/components/containers:bounding_box",
+        "//mediapipe/tasks/python/components/containers:category",
+        "//mediapipe/tasks/python/components/containers:detections",
+        "//mediapipe/tasks/python/core:base_options",
+        "//mediapipe/tasks/python/test:test_utils",
+        "//mediapipe/tasks/python/vision:face_detector",
+        "//mediapipe/tasks/python/vision/core:image_processing_options",
+        "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
+        "@com_google_protobuf//:protobuf_python",
+    ],
+)
+
 py_test(
     name = "hand_landmarker_test",
     srcs = ["hand_landmarker_test.py"],


@@ -0,0 +1,523 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for face detector."""
import enum
import os
from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
from google.protobuf import text_format
from mediapipe.framework.formats import detection_pb2
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.components.containers import bounding_box as bounding_box_module
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import detections as detections_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import face_detector
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
FaceDetectorResult = detections_module.DetectionResult
_BaseOptions = base_options_module.BaseOptions
_Category = category_module.Category
_BoundingBox = bounding_box_module.BoundingBox
_Detection = detections_module.Detection
_Image = image_module.Image
_FaceDetector = face_detector.FaceDetector
_FaceDetectorOptions = face_detector.FaceDetectorOptions
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_SHORT_RANGE_BLAZE_FACE_MODEL = 'face_detection_short_range.tflite'
_PORTRAIT_IMAGE = 'portrait.jpg'
_PORTRAIT_EXPECTED_DETECTION = 'portrait_expected_detection.pbtxt'
_PORTRAIT_ROTATED_IMAGE = 'portrait_rotated.jpg'
_PORTRAIT_ROTATED_EXPECTED_DETECTION = (
'portrait_rotated_expected_detection.pbtxt'
)
_CAT_IMAGE = 'cat.jpg'
_KEYPOINT_ERROR_THRESHOLD = 1e-2
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
def _get_expected_face_detector_result(file_name: str) -> FaceDetectorResult:
face_detection_result_file_path = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, file_name)
)
with open(face_detection_result_file_path, 'rb') as f:
face_detection_proto = detection_pb2.Detection()
text_format.Parse(f.read(), face_detection_proto)
face_detection = detections_module.Detection.create_from_pb2(
face_detection_proto
)
return FaceDetectorResult(detections=[face_detection])
class ModelFileType(enum.Enum):
FILE_CONTENT = 1
FILE_NAME = 2
class FaceDetectorTest(parameterized.TestCase):
def setUp(self):
super().setUp()
self.test_image = _Image.create_from_file(
test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _PORTRAIT_IMAGE)
)
)
self.model_path = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _SHORT_RANGE_BLAZE_FACE_MODEL)
)
def test_create_from_file_succeeds_with_valid_model_path(self):
# Creates with default option and valid model file successfully.
with _FaceDetector.create_from_model_path(self.model_path) as detector:
self.assertIsInstance(detector, _FaceDetector)
def test_create_from_options_succeeds_with_valid_model_path(self):
# Creates with options containing model file successfully.
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _FaceDetectorOptions(base_options=base_options)
with _FaceDetector.create_from_options(options) as detector:
self.assertIsInstance(detector, _FaceDetector)
def test_create_from_options_fails_with_invalid_model_path(self):
with self.assertRaisesRegex(
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'
):
base_options = _BaseOptions(
model_asset_path='/path/to/invalid/model.tflite'
)
options = _FaceDetectorOptions(base_options=base_options)
_FaceDetector.create_from_options(options)
def test_create_from_options_succeeds_with_valid_model_content(self):
# Creates with options containing model content successfully.
with open(self.model_path, 'rb') as f:
base_options = _BaseOptions(model_asset_buffer=f.read())
options = _FaceDetectorOptions(base_options=base_options)
detector = _FaceDetector.create_from_options(options)
self.assertIsInstance(detector, _FaceDetector)
def _expect_keypoints_correct(self, actual_keypoints, expected_keypoints):
self.assertLen(actual_keypoints, len(expected_keypoints))
for i in range(len(actual_keypoints)):
self.assertAlmostEqual(
actual_keypoints[i].x,
expected_keypoints[i].x,
delta=_KEYPOINT_ERROR_THRESHOLD,
)
self.assertAlmostEqual(
actual_keypoints[i].y,
expected_keypoints[i].y,
delta=_KEYPOINT_ERROR_THRESHOLD,
)
def _expect_face_detector_results_correct(
self, actual_results, expected_results
):
self.assertLen(actual_results.detections, len(expected_results.detections))
for i in range(len(actual_results.detections)):
actual_bbox = actual_results.detections[i].bounding_box
expected_bbox = expected_results.detections[i].bounding_box
self.assertEqual(actual_bbox, expected_bbox)
self.assertNotEmpty(actual_results.detections[i].keypoints)
self._expect_keypoints_correct(
actual_results.detections[i].keypoints,
expected_results.detections[i].keypoints,
)
@parameterized.parameters(
(ModelFileType.FILE_NAME, _PORTRAIT_EXPECTED_DETECTION),
(ModelFileType.FILE_CONTENT, _PORTRAIT_EXPECTED_DETECTION),
)
def test_detect(self, model_file_type, expected_detection_result_file):
# Creates detector.
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
with open(self.model_path, 'rb') as f:
model_content = f.read()
base_options = _BaseOptions(model_asset_buffer=model_content)
else:
# Should never happen
raise ValueError('model_file_type is invalid.')
options = _FaceDetectorOptions(base_options=base_options)
detector = _FaceDetector.create_from_options(options)
# Performs face detection on the input.
detection_result = detector.detect(self.test_image)
# Comparing results.
expected_detection_result = _get_expected_face_detector_result(
expected_detection_result_file
)
self._expect_face_detector_results_correct(
detection_result, expected_detection_result
)
# Closes the detector explicitly when the detector is not used in
# a context.
detector.close()
@parameterized.parameters(
(ModelFileType.FILE_NAME, _PORTRAIT_EXPECTED_DETECTION),
(ModelFileType.FILE_CONTENT, _PORTRAIT_EXPECTED_DETECTION),
)
def test_detect_in_context(
self, model_file_type, expected_detection_result_file
):
# Creates detector.
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
with open(self.model_path, 'rb') as f:
model_content = f.read()
base_options = _BaseOptions(model_asset_buffer=model_content)
else:
# Should never happen
raise ValueError('model_file_type is invalid.')
options = _FaceDetectorOptions(base_options=base_options)
with _FaceDetector.create_from_options(options) as detector:
# Performs face detection on the input.
detection_result = detector.detect(self.test_image)
# Comparing results.
expected_detection_result = _get_expected_face_detector_result(
expected_detection_result_file
)
self._expect_face_detector_results_correct(
detection_result, expected_detection_result
)
def test_detect_succeeds_with_rotated_image(self):
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _FaceDetectorOptions(base_options=base_options)
with _FaceDetector.create_from_options(options) as detector:
# Load the test image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _PORTRAIT_ROTATED_IMAGE)
)
)
# Rotated input image.
image_processing_options = _ImageProcessingOptions(rotation_degrees=-90)
# Performs face detection on the input.
detection_result = detector.detect(test_image, image_processing_options)
# Comparing results.
expected_detection_result = _get_expected_face_detector_result(
_PORTRAIT_ROTATED_EXPECTED_DETECTION
)
self._expect_face_detector_results_correct(
detection_result, expected_detection_result
)
def test_empty_detection_outputs(self):
# Load a test image with no faces.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))
)
options = _FaceDetectorOptions(
base_options=_BaseOptions(model_asset_path=self.model_path)
)
with _FaceDetector.create_from_options(options) as detector:
# Performs face detection on the input.
detection_result = detector.detect(test_image)
self.assertEmpty(detection_result.detections)
def test_missing_result_callback(self):
options = _FaceDetectorOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
)
with self.assertRaisesRegex(
ValueError, r'result callback must be provided'
):
with _FaceDetector.create_from_options(options) as unused_detector:
pass
@parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
def test_illegal_result_callback(self, running_mode):
options = _FaceDetectorOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=running_mode,
result_callback=mock.MagicMock(),
)
with self.assertRaisesRegex(
ValueError, r'result callback should not be provided'
):
with _FaceDetector.create_from_options(options) as unused_detector:
pass
def test_calling_detect_for_video_in_image_mode(self):
options = _FaceDetectorOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE,
)
with _FaceDetector.create_from_options(options) as detector:
with self.assertRaisesRegex(
ValueError, r'not initialized with the video mode'
):
detector.detect_for_video(self.test_image, 0)
def test_calling_detect_async_in_image_mode(self):
options = _FaceDetectorOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE,
)
with _FaceDetector.create_from_options(options) as detector:
with self.assertRaisesRegex(
ValueError, r'not initialized with the live stream mode'
):
detector.detect_async(self.test_image, 0)
def test_calling_detect_in_video_mode(self):
options = _FaceDetectorOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO,
)
with _FaceDetector.create_from_options(options) as detector:
with self.assertRaisesRegex(
ValueError, r'not initialized with the image mode'
):
detector.detect(self.test_image)
def test_calling_detect_async_in_video_mode(self):
options = _FaceDetectorOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO,
)
with _FaceDetector.create_from_options(options) as detector:
with self.assertRaisesRegex(
ValueError, r'not initialized with the live stream mode'
):
detector.detect_async(self.test_image, 0)
def test_detect_for_video_with_out_of_order_timestamp(self):
options = _FaceDetectorOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO,
)
with _FaceDetector.create_from_options(options) as detector:
unused_result = detector.detect_for_video(self.test_image, 1)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'
):
detector.detect_for_video(self.test_image, 0)
@parameterized.parameters(
(
ModelFileType.FILE_NAME,
_PORTRAIT_IMAGE,
0,
_get_expected_face_detector_result(_PORTRAIT_EXPECTED_DETECTION),
),
(
ModelFileType.FILE_CONTENT,
_PORTRAIT_IMAGE,
0,
_get_expected_face_detector_result(_PORTRAIT_EXPECTED_DETECTION),
),
(
ModelFileType.FILE_NAME,
_PORTRAIT_ROTATED_IMAGE,
-90,
_get_expected_face_detector_result(
_PORTRAIT_ROTATED_EXPECTED_DETECTION
),
),
(
ModelFileType.FILE_CONTENT,
_PORTRAIT_ROTATED_IMAGE,
-90,
_get_expected_face_detector_result(
_PORTRAIT_ROTATED_EXPECTED_DETECTION
),
),
(ModelFileType.FILE_NAME, _CAT_IMAGE, 0, FaceDetectorResult([])),
(ModelFileType.FILE_CONTENT, _CAT_IMAGE, 0, FaceDetectorResult([])),
)
def test_detect_for_video(
self,
model_file_type,
test_image_file_name,
rotation_degrees,
expected_detection_result,
):
# Creates detector.
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
with open(self.model_path, 'rb') as f:
model_content = f.read()
base_options = _BaseOptions(model_asset_buffer=model_content)
else:
# Should never happen
raise ValueError('model_file_type is invalid.')
options = _FaceDetectorOptions(
base_options=base_options, running_mode=_RUNNING_MODE.VIDEO
)
with _FaceDetector.create_from_options(options) as detector:
for timestamp in range(0, 300, 30):
# Load the test image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, test_image_file_name)
)
)
# Set the image processing options.
image_processing_options = _ImageProcessingOptions(
rotation_degrees=rotation_degrees
)
# Performs face detection on the input.
detection_result = detector.detect_for_video(
test_image, timestamp, image_processing_options
)
# Comparing results.
self._expect_face_detector_results_correct(
detection_result, expected_detection_result
)
def test_calling_detect_in_live_stream_mode(self):
options = _FaceDetectorOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock(),
)
with _FaceDetector.create_from_options(options) as detector:
with self.assertRaisesRegex(
ValueError, r'not initialized with the image mode'
):
detector.detect(self.test_image)
def test_calling_detect_for_video_in_live_stream_mode(self):
options = _FaceDetectorOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock(),
)
with _FaceDetector.create_from_options(options) as detector:
with self.assertRaisesRegex(
ValueError, r'not initialized with the video mode'
):
detector.detect_for_video(self.test_image, 0)
def test_detect_async_calls_with_illegal_timestamp(self):
options = _FaceDetectorOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock(),
)
with _FaceDetector.create_from_options(options) as detector:
detector.detect_async(self.test_image, 100)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'
):
detector.detect_async(self.test_image, 0)
@parameterized.parameters(
(
ModelFileType.FILE_NAME,
_PORTRAIT_IMAGE,
0,
_get_expected_face_detector_result(_PORTRAIT_EXPECTED_DETECTION),
),
(
ModelFileType.FILE_CONTENT,
_PORTRAIT_IMAGE,
0,
_get_expected_face_detector_result(_PORTRAIT_EXPECTED_DETECTION),
),
(
ModelFileType.FILE_NAME,
_PORTRAIT_ROTATED_IMAGE,
-90,
_get_expected_face_detector_result(
_PORTRAIT_ROTATED_EXPECTED_DETECTION
),
),
(
ModelFileType.FILE_CONTENT,
_PORTRAIT_ROTATED_IMAGE,
-90,
_get_expected_face_detector_result(
_PORTRAIT_ROTATED_EXPECTED_DETECTION
),
),
(ModelFileType.FILE_NAME, _CAT_IMAGE, 0, FaceDetectorResult([])),
(ModelFileType.FILE_CONTENT, _CAT_IMAGE, 0, FaceDetectorResult([])),
)
def test_detect_async_calls(
self,
model_file_type,
test_image_file_name,
rotation_degrees,
expected_detection_result,
):
# Creates detector.
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
with open(self.model_path, 'rb') as f:
model_content = f.read()
base_options = _BaseOptions(model_asset_buffer=model_content)
else:
# Should never happen
raise ValueError('model_file_type is invalid.')
    observed_timestamp_ms = -1

    def check_result(
        result: FaceDetectorResult,
        unused_output_image: _Image,
        timestamp_ms: int,
    ):
      nonlocal observed_timestamp_ms
      self._expect_face_detector_results_correct(
          result, expected_detection_result
      )
      self.assertLess(observed_timestamp_ms, timestamp_ms)
      observed_timestamp_ms = timestamp_ms
options = _FaceDetectorOptions(
base_options=base_options,
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=check_result,
)
# Load the test image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, test_image_file_name)
)
)
with _FaceDetector.create_from_options(options) as detector:
for timestamp in range(0, 300, 30):
# Set the image processing options.
image_processing_options = _ImageProcessingOptions(
rotation_degrees=rotation_degrees
)
detector.detect_async(test_image, timestamp, image_processing_options)
if __name__ == '__main__':
absltest.main()


@@ -153,6 +153,26 @@ py_library(
     ],
 )
 
+py_library(
+    name = "face_detector",
+    srcs = [
+        "face_detector.py",
+    ],
+    deps = [
+        "//mediapipe/python:_framework_bindings",
+        "//mediapipe/python:packet_creator",
+        "//mediapipe/python:packet_getter",
+        "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_py_pb2",
+        "//mediapipe/tasks/python/components/containers:detections",
+        "//mediapipe/tasks/python/core:base_options",
+        "//mediapipe/tasks/python/core:optional_dependencies",
+        "//mediapipe/tasks/python/core:task_info",
+        "//mediapipe/tasks/python/vision/core:base_vision_task_api",
+        "//mediapipe/tasks/python/vision/core:image_processing_options",
+        "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
+    ],
+)
+
 py_library(
     name = "face_landmarker",
     srcs = [


@@ -0,0 +1,332 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe face detector task."""
import dataclasses
from typing import Callable, Mapping, Optional
from mediapipe.python import packet_creator
from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet as packet_module
from mediapipe.tasks.cc.vision.face_detector.proto import face_detector_graph_options_pb2
from mediapipe.tasks.python.components.containers import detections as detections_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
FaceDetectorResult = detections_module.DetectionResult
_BaseOptions = base_options_module.BaseOptions
_FaceDetectorGraphOptionsProto = (
face_detector_graph_options_pb2.FaceDetectorGraphOptions
)
_RunningMode = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo
_DETECTIONS_OUT_STREAM_NAME = 'detections'
_DETECTIONS_TAG = 'DETECTIONS'
_NORM_RECT_STREAM_NAME = 'norm_rect_in'
_NORM_RECT_TAG = 'NORM_RECT'
_IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.face_detector.FaceDetectorGraph'
_MICRO_SECONDS_PER_MILLISECOND = 1000
@dataclasses.dataclass
class FaceDetectorOptions:
"""Options for the face detector task.
Attributes:
base_options: Base options for the face detector task.
    running_mode: The running mode of the task. Defaults to the image mode.
      The face detector task has three running modes: 1) The image mode for
      detecting faces on single image inputs. 2) The video mode for detecting
      faces on the decoded frames of a video. 3) The live stream mode for
      detecting faces on a live stream of input data, such as from a camera.
    min_detection_confidence: The minimum confidence score for a face
      detection to be considered successful.
    min_suppression_threshold: The minimum non-maximum-suppression threshold
      for face detections to be considered overlapped.
result_callback: The user-defined result callback for processing live stream
data. The result callback should only be specified when the running mode
is set to the live stream mode.
"""
base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE
min_detection_confidence: Optional[float] = None
min_suppression_threshold: Optional[float] = None
result_callback: Optional[
Callable[
[detections_module.DetectionResult, image_module.Image, int], None
]
] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _FaceDetectorGraphOptionsProto:
"""Generates an FaceDetectorOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = (
False if self.running_mode == _RunningMode.IMAGE else True
)
return _FaceDetectorGraphOptionsProto(
base_options=base_options_proto,
min_detection_confidence=self.min_detection_confidence,
min_suppression_threshold=self.min_suppression_threshold,
)
class FaceDetector(base_vision_task_api.BaseVisionTaskApi):
"""Class that performs face detection on images."""
@classmethod
def create_from_model_path(cls, model_path: str) -> 'FaceDetector':
"""Creates an `FaceDetector` object from a TensorFlow Lite model and the default `FaceDetectorOptions`.
Note that the created `FaceDetector` instance is in image mode, for
detecting faces on single image inputs.
Args:
model_path: Path to the model.
Returns:
`FaceDetector` object that's created from the model file and the default
`FaceDetectorOptions`.
    Raises:
      ValueError: If the `FaceDetector` object cannot be created from the
        provided file, for example because of an invalid file path.
      RuntimeError: If any other type of error occurs.
"""
base_options = _BaseOptions(model_asset_path=model_path)
options = FaceDetectorOptions(
base_options=base_options, running_mode=_RunningMode.IMAGE
)
return cls.create_from_options(options)
@classmethod
def create_from_options(cls, options: FaceDetectorOptions) -> 'FaceDetector':
"""Creates the `FaceDetector` object from face detector options.
Args:
options: Options for the face detector task.
Returns:
`FaceDetector` object that's created from `options`.
    Raises:
      ValueError: If the `FaceDetector` object cannot be created from
        `FaceDetectorOptions`, for example because the model is missing.
      RuntimeError: If any other type of error occurs.
"""
def packets_callback(output_packets: Mapping[str, packet_module.Packet]):
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
return
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
if output_packets[_DETECTIONS_OUT_STREAM_NAME].is_empty():
empty_packet = output_packets[_DETECTIONS_OUT_STREAM_NAME]
options.result_callback(
FaceDetectorResult([]),
image,
empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
)
return
detection_proto_list = packet_getter.get_proto_list(
output_packets[_DETECTIONS_OUT_STREAM_NAME]
)
detection_result = detections_module.DetectionResult(
[
detections_module.Detection.create_from_pb2(result)
for result in detection_proto_list
]
)
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
options.result_callback(
detection_result,
image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
)
task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME,
input_streams=[
':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
],
output_streams=[
':'.join([_DETECTIONS_TAG, _DETECTIONS_OUT_STREAM_NAME]),
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
],
task_options=options,
)
return cls(
task_info.generate_graph_config(
enable_flow_limiting=options.running_mode
== _RunningMode.LIVE_STREAM
),
options.running_mode,
packets_callback if options.result_callback else None,
)
def detect(
self,
image: image_module.Image,
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> FaceDetectorResult:
"""Performs face detection on the provided MediaPipe Image.
Only use this method when the FaceDetector is created with the image
running mode.
Args:
image: MediaPipe Image.
image_processing_options: Options for image processing.
Returns:
A face detection result object that contains a list of face detections,
each detection has a bounding box that is expressed in the unrotated input
frame of reference coordinates system, i.e. in `[0,image_width) x [0,
image_height)`, which are the dimensions of the underlying image data.
Raises:
ValueError: If any of the input arguments is invalid.
RuntimeError: If face detection failed to run.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False
)
output_packets = self._process_image_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
),
})
if output_packets[_DETECTIONS_OUT_STREAM_NAME].is_empty():
return FaceDetectorResult([])
detection_proto_list = packet_getter.get_proto_list(
output_packets[_DETECTIONS_OUT_STREAM_NAME]
)
return detections_module.DetectionResult(
[
detections_module.Detection.create_from_pb2(result)
for result in detection_proto_list
]
)
def detect_for_video(
self,
image: image_module.Image,
timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> detections_module.DetectionResult:
"""Performs face detection on the provided video frames.
Only use this method when the FaceDetector is created with the video
running mode. It's required to provide the video frame's timestamp (in
milliseconds) along with the video frame. The input timestamps should be
monotonically increasing for adjacent calls of this method.
Args:
image: MediaPipe Image.
timestamp_ms: The timestamp of the input video frame in milliseconds.
image_processing_options: Options for image processing.
Returns:
A face detection result object that contains a list of face detections,
each detection has a bounding box that is expressed in the unrotated input
frame of reference coordinates system, i.e. in `[0,image_width) x [0,
image_height)`, which are the dimensions of the underlying image data.
Raises:
ValueError: If any of the input arguments is invalid.
RuntimeError: If face detection failed to run.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False
)
output_packets = self._process_video_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
})
if output_packets[_DETECTIONS_OUT_STREAM_NAME].is_empty():
return FaceDetectorResult([])
detection_proto_list = packet_getter.get_proto_list(
output_packets[_DETECTIONS_OUT_STREAM_NAME]
)
return detections_module.DetectionResult(
[
detections_module.Detection.create_from_pb2(result)
for result in detection_proto_list
]
)
def detect_async(
self,
image: image_module.Image,
timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> None:
"""Sends live image data (an Image with a unique timestamp) to perform face detection.
Only use this method when the FaceDetector is created with the live stream
running mode. The input timestamps should be monotonically increasing for
adjacent calls of this method. This method will return immediately after the
input image is accepted. The results will be available via the
`result_callback` provided in the `FaceDetectorOptions`. The
`detect_async` method is designed to process live stream data such as camera
    input. To lower the overall latency, the face detector may drop input
    images if needed. In other words, it is not guaranteed to produce an
    output for every input image.
The `result_callback` provides:
- A face detection result object that contains a list of face detections,
each detection has a bounding box that is expressed in the unrotated
input frame of reference coordinates system,
i.e. in `[0,image_width) x [0,image_height)`, which are the dimensions
of the underlying image data.
- The input image that the face detector runs on.
- The input timestamp in milliseconds.
Args:
image: MediaPipe Image.
timestamp_ms: The timestamp of the input image in milliseconds.
image_processing_options: Options for image processing.
Raises:
ValueError: If the current input timestamp is smaller than what the face
detector has already processed.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False
)
self._send_live_stream_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
})
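
Putting the task above together, a minimal image-mode usage sketch. This assumes a MediaPipe Python build that exports the task under `mp.tasks.vision`; 'face_detection_short_range.tflite' and 'portrait.jpg' are placeholder asset paths, not files shipped with this change:

import mediapipe as mp

BaseOptions = mp.tasks.BaseOptions
FaceDetector = mp.tasks.vision.FaceDetector
FaceDetectorOptions = mp.tasks.vision.FaceDetectorOptions
VisionRunningMode = mp.tasks.vision.RunningMode

# Create the detector in image mode and run one-shot detection.
options = FaceDetectorOptions(
    base_options=BaseOptions(
        model_asset_path='face_detection_short_range.tflite'  # placeholder
    ),
    running_mode=VisionRunningMode.IMAGE,
)
with FaceDetector.create_from_options(options) as detector:
  image = mp.Image.create_from_file('portrait.jpg')  # placeholder
  result = detector.detect(image)
  for detection in result.detections:
    # Each detection carries a bounding box and normalized keypoints.
    print(detection.bounding_box, detection.keypoints)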


@@ -31,6 +31,8 @@ mediapipe_files(srcs = [
     "cat_rotated.jpg",
     "cat_rotated_mask.jpg",
     "cats_and_dogs.jpg",
+    "cats_and_dogs_mask_dog1.png",
+    "cats_and_dogs_mask_dog2.png",
     "cats_and_dogs_no_resizing.jpg",
     "cats_and_dogs_rotated.jpg",
     "coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite",
@@ -70,6 +72,9 @@ mediapipe_files(srcs = [
     "portrait.jpg",
     "portrait_hair_expected_mask.jpg",
     "portrait_rotated.jpg",
+    "portrait_selfie_segmentation_expected_category_mask.jpg",
+    "portrait_selfie_segmentation_expected_confidence_mask.jpg",
+    "portrait_selfie_segmentation_landscape_expected_category_mask.jpg",
     "pose.jpg",
     "pose_detection.tflite",
     "right_hands.jpg",
@@ -113,6 +118,8 @@ filegroup(
         "cat_rotated.jpg",
         "cat_rotated_mask.jpg",
         "cats_and_dogs.jpg",
+        "cats_and_dogs_mask_dog1.png",
+        "cats_and_dogs_mask_dog2.png",
         "cats_and_dogs_no_resizing.jpg",
         "cats_and_dogs_rotated.jpg",
         "fist.jpg",
@@ -129,6 +136,9 @@ filegroup(
         "portrait.jpg",
         "portrait_hair_expected_mask.jpg",
         "portrait_rotated.jpg",
+        "portrait_selfie_segmentation_expected_category_mask.jpg",
+        "portrait_selfie_segmentation_expected_confidence_mask.jpg",
+        "portrait_selfie_segmentation_landscape_expected_category_mask.jpg",
         "pose.jpg",
         "right_hands.jpg",
         "right_hands_rotated.jpg",

third_party/BUILD

@@ -169,11 +169,7 @@ cmake_external(
         "-lm",
         "-lpthread",
         "-lrt",
-    ] + select({
-        "//mediapipe:ios": ["-framework Cocoa"],
-        "//mediapipe:macos": ["-framework Cocoa"],
-        "//conditions:default": [],
-    }),
+    ],
     shared_libraries = select({
         "@bazel_tools//src/conditions:darwin": ["libopencv_%s.%s.dylib" % (module, OPENCV_SO_VERSION) for module in OPENCV_MODULES],
         # Only the shared objects listed here will be linked in the directory

@@ -67,13 +67,7 @@ def external_files():
     http_file(
         name = "com_google_mediapipe_BUILD",
         sha256 = "d2b2a8346202691d7f831887c84e9642e974f64ed67851d9a58cf15c94b1f6b3",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=16618756636939761678323576393653"],
-    )
-
-    http_file(
-        name = "com_google_mediapipe_BUILD_orig",
-        sha256 = "d86b98b82e00dd87cd46bd1429bf5eaa007b500c1a24d9316b73309f2e6c8df8",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD.orig?generation=1678737479599640"],
+        urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=166187566369397616783235763936531678737479599640"],
     )
 
     http_file(
@@ -136,6 +130,18 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/cats_and_dogs.jpg?generation=1661875684064150"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_cats_and_dogs_mask_dog1_png",
+        sha256 = "2ab37d56ba1e46e70b3ddbfe35dac51b18b597b76904c68d7d34c7c74c677d4c",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/cats_and_dogs_mask_dog1.png?generation=1678840350058498"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_cats_and_dogs_mask_dog2_png",
+        sha256 = "2010850e2dd7f520fe53b9086d70913b6fb53b178cae15a373e5ee7ffb46824a",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/cats_and_dogs_mask_dog2.png?generation=1678840352961684"],
+    )
+
     http_file(
         name = "com_google_mediapipe_cats_and_dogs_no_resizing_jpg",
         sha256 = "9d55933ed66bcdc63cd6509ee2518d7eed75d12db609238387ee4cc50b173e58",
@@ -886,6 +892,24 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_rotated.jpg?generation=1677194680138164"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_portrait_selfie_segmentation_expected_category_mask_jpg",
+        sha256 = "d8f20fa746e14067f668dd293f21bbc50ec81196d186386a6ded1278c3ec8f46",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_selfie_segmentation_expected_category_mask.jpg?generation=1678606935088873"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_portrait_selfie_segmentation_expected_confidence_mask_jpg",
+        sha256 = "25b723e90608edaf6ed92f382da703dc904a59c87525b6d271e60d9eed7a90e9",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_selfie_segmentation_expected_confidence_mask.jpg?generation=1678606937358235"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_portrait_selfie_segmentation_landscape_expected_category_mask_jpg",
+        sha256 = "f5c3fa3d93f8e7289b69b8a89c2519276dfa5014dcc50ed6e86e8cd4d4ae7f27",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_selfie_segmentation_landscape_expected_category_mask.jpg?generation=1678606939469429"],
+    )
+
     http_file(
         name = "com_google_mediapipe_pose_detection_tflite",
         sha256 = "9ba9dd3d42efaaba86b4ff0122b06f29c4122e756b329d89dca1e297fd8f866c",
@@ -1014,8 +1038,8 @@ def external_files():
     http_file(
         name = "com_google_mediapipe_selfie_segm_128_128_3_expected_mask_jpg",
-        sha256 = "a295f3ab394a5e0caff2db5041337da58341ec331f1413ef91f56e0d650b4a1e",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segm_128_128_3_expected_mask.jpg?generation=1661875916766416"],
+        sha256 = "1a2a068287d8bcd4184492485b3dbb95a09b763f4653fd729d14a836147eb383",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segm_128_128_3_expected_mask.jpg?generation=1678606942616777"],
     )
 
     http_file(
@@ -1026,8 +1050,8 @@ def external_files():
     http_file(
         name = "com_google_mediapipe_selfie_segm_144_256_3_expected_mask_jpg",
-        sha256 = "cfc699db9670585c04414d0d1a07b289a027ba99d6903d2219f897d34e2c9952",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segm_144_256_3_expected_mask.jpg?generation=1661875922646736"],
+        sha256 = "2de433b6e8adabec2aaf80135232db900903ead4f2811c0c9378a6792b2a68b5",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segm_144_256_3_expected_mask.jpg?generation=1678606945085676"],
     )
 
     http_file(