Merge branch 'master' into face-landmarker-python
commit 15d90bd325

Changed files: LICENSE (17 lines)
@@ -199,3 +199,20 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

===========================================================================
For files under tasks/cc/text/language_detector/custom_ops/utils/utf/
===========================================================================
/*
 * The authors of this software are Rob Pike and Ken Thompson.
 * Copyright (c) 2002 by Lucent Technologies.
 * Permission to use, copy, modify, and distribute this software for any
 * purpose without fee is hereby granted, provided that this entire notice
 * is included in all copies of any software which is or includes a copy
 * or modification of this software and in all copies of the supporting
 * documentation for such software.
 * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED
 * WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY
 * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY
 * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE.
 */
@@ -96,6 +96,7 @@ cc_library(
        "//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
        "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
        "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
        "//mediapipe/tasks/cc/vision/face_detector:face_detector_graph",
        "//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
    ] + select({
        # TODO: Build text_classifier_graph and text_embedder_graph on Windows.
@@ -42,3 +42,36 @@ cc_test(
        "@org_tensorflow//tensorflow/lite/kernels:test_util",
    ],
)

cc_library(
    name = "ngram_hash",
    srcs = ["ngram_hash.cc"],
    hdrs = ["ngram_hash.h"],
    copts = tflite_copts(),
    deps = [
        "//mediapipe/tasks/cc/text/language_detector/custom_ops/utils:ngram_hash_ops_utils",
        "//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash:murmur",
        "@flatbuffers",
        "@org_tensorflow//tensorflow/lite:string_util",
        "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
        "@org_tensorflow//tensorflow/lite/kernels:kernel_util",
    ],
    alwayslink = 1,
)

cc_test(
    name = "ngram_hash_test",
    srcs = ["ngram_hash_test.cc"],
    deps = [
        ":ngram_hash",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash:murmur",
        "@com_google_absl//absl/types:optional",
        "@flatbuffers",
        "@org_tensorflow//tensorflow/lite:framework",
        "@org_tensorflow//tensorflow/lite:string_util",
        "@org_tensorflow//tensorflow/lite/c:common",
        "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
        "@org_tensorflow//tensorflow/lite/kernels:test_util",
    ],
)
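One detail worth noting in the `ngram_hash` target above: `alwayslink = 1` forces the kernel's object code to be linked into dependents even when no symbol in it is referenced at link time, the usual Bazel pattern for TensorFlow/TFLite op libraries.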
@@ -0,0 +1,264 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"

#include <cstdint>
#include <string>
#include <vector>

#include "flatbuffers/flexbuffers.h"
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h"
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/string_util.h"

namespace tflite::ops::custom {

namespace ngram_op {

namespace {

using ::flexbuffers::GetRoot;
using ::flexbuffers::Map;
using ::flexbuffers::TypedVector;
using ::mediapipe::tasks::text::language_detector::custom_ops::
    LowercaseUnicodeStr;
using ::mediapipe::tasks::text::language_detector::custom_ops::Tokenize;
using ::mediapipe::tasks::text::language_detector::custom_ops::TokenizedOutput;
using ::mediapipe::tasks::text::language_detector::custom_ops::hash::
    MurmurHash64WithSeed;
using ::tflite::GetString;
using ::tflite::StringRef;

constexpr int kInputMessage = 0;
constexpr int kOutputLabel = 0;
constexpr int kDefaultMaxSplits = 128;

// This op takes in a string, finds the character ngrams for it and then
// maps each of these ngrams to an index using the specified vocabulary sizes.

// Input(s):
// - input: Input string.
// - seeds: Seed for the random number generator.
// - ngram_lengths: Lengths of each of the ngrams. For example [1, 2, 3] would
//   be interpreted as generating unigrams, bigrams, and trigrams.
// - vocab_sizes: Size of the vocabulary for each of the ngram features
//   respectively. The op would generate vocab ids to be less than or equal to
//   the vocab size. The index 0 implies an invalid ngram.
// - max_splits: Maximum number of tokens in the output. If this is unset, the
//   limit is `kDefaultMaxSplits`.
// - lower_case_input: If this is set to true, the input string would be
//   lower-cased before any processing.

// Output(s):
// - output: A tensor of size [number of ngrams, number of tokens + 2],
//   where 2 tokens are reserved for the padding. If `max_splits` is set, this
//   length is <= max_splits, otherwise it is <= `kDefaultMaxSplits`.
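//
// For example, with `ngram_lengths` = [1, 2] and input "hi" (tokenized with
// padding as "^", "h", "i", "$"), the output has shape [1, 2, 4]: one row of
// four unigram ids and one row of four bigram ids.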

// Helper class used for pre-processing the input.
class NGramHashParams {
 public:
  NGramHashParams(const uint64_t seed, const std::vector<int>& ngram_lengths,
                  const std::vector<int>& vocab_sizes, int max_splits,
                  bool lower_case_input)
      : seed_(seed),
        ngram_lengths_(ngram_lengths),
        vocab_sizes_(vocab_sizes),
        max_splits_(max_splits),
        lower_case_input_(lower_case_input) {}

  TfLiteStatus PreprocessInput(const TfLiteTensor* input_t,
                               TfLiteContext* context) {
    if (input_t->bytes == 0) {
      context->ReportError(context, "Empty input not supported.");
      return kTfLiteError;
    }

    // Do sanity checks on the input.
    if (ngram_lengths_.empty()) {
      context->ReportError(context, "`ngram_lengths` must be non-empty.");
      return kTfLiteError;
    }

    if (vocab_sizes_.empty()) {
      context->ReportError(context, "`vocab_sizes` must be non-empty.");
      return kTfLiteError;
    }

    if (ngram_lengths_.size() != vocab_sizes_.size()) {
      context->ReportError(
          context,
          "Sizes of `ngram_lengths` and `vocab_sizes` must be the same.");
      return kTfLiteError;
    }

    if (max_splits_ <= 0) {
      context->ReportError(context, "`max_splits` must be > 0.");
      return kTfLiteError;
    }

    // Obtain and tokenize the input.
    StringRef inputref = GetString(input_t, /*string_index=*/0);
    if (lower_case_input_) {
      std::string lower_cased_str;
      LowercaseUnicodeStr(inputref.str, inputref.len, &lower_cased_str);

      tokenized_output_ =
          Tokenize(lower_cased_str.c_str(), inputref.len, max_splits_,
                   /*exclude_nonalphaspace_tokens=*/true);
    } else {
      tokenized_output_ = Tokenize(inputref.str, inputref.len, max_splits_,
                                   /*exclude_nonalphaspace_tokens=*/true);
    }
    return kTfLiteOk;
  }
  uint64_t GetSeed() const { return seed_; }

  int GetNumTokens() const { return tokenized_output_.tokens.size(); }

  int GetNumNGrams() const { return ngram_lengths_.size(); }

  std::vector<int> GetNGramLengths() const { return ngram_lengths_; }

  std::vector<int> GetVocabSizes() const { return vocab_sizes_; }

  const TokenizedOutput& GetTokenizedOutput() const {
    return tokenized_output_;
  }

  TokenizedOutput tokenized_output_;

 private:
  const uint64_t seed_;
  std::vector<int> ngram_lengths_;
  std::vector<int> vocab_sizes_;
  const int max_splits_;
  const bool lower_case_input_;
};

// Convert the TypedVector into a regular std::vector.
std::vector<int> GetIntVector(TypedVector typed_vec) {
  std::vector<int> vec(typed_vec.size());
  for (int j = 0; j < typed_vec.size(); j++) {
    vec[j] = typed_vec[j].AsInt32();
  }
  return vec;
}

void GetNGramHashIndices(NGramHashParams* params, int32_t* data) {
  const int max_unicode_length = params->GetNumTokens();
  const auto ngram_lengths = params->GetNGramLengths();
  const auto vocab_sizes = params->GetVocabSizes();
  const auto& tokenized_output = params->GetTokenizedOutput();
  const auto seed = params->GetSeed();

  // Compute for each ngram.
  for (int ngram = 0; ngram < ngram_lengths.size(); ngram++) {
    const int vocab_size = vocab_sizes[ngram];
    const int ngram_length = ngram_lengths[ngram];

    // Compute for each token within the input.
    for (int start = 0; start < tokenized_output.tokens.size(); start++) {
      // Compute the number of bytes for the ngram starting at the given
      // token.
      int num_bytes = 0;
      for (int i = start;
           i < tokenized_output.tokens.size() && i < (start + ngram_length);
           i++) {
        num_bytes += tokenized_output.tokens[i].second;
      }

      // Compute the hash for the ngram starting at the token.
      const auto str_hash = MurmurHash64WithSeed(
          tokenized_output.str.c_str() + tokenized_output.tokens[start].first,
          num_bytes, seed);

      // Map the hash to an index in the vocab.
      data[ngram * max_unicode_length + start] = (str_hash % vocab_size) + 1;
    }
  }
}

}  // namespace

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
  const Map& m = GetRoot(buffer_t, length).AsMap();

  const uint64_t seed = m["seed"].AsUInt64();
  const std::vector<int> ngram_lengths =
      GetIntVector(m["ngram_lengths"].AsTypedVector());
  const std::vector<int> vocab_sizes =
      GetIntVector(m["vocab_sizes"].AsTypedVector());
  const int max_splits =
      m["max_splits"].IsNull() ? kDefaultMaxSplits : m["max_splits"].AsInt32();
  const bool lowercase_input =
      m["lowercase_input"].IsNull() ? true : m["lowercase_input"].AsBool();

  return new NGramHashParams(seed, ngram_lengths, vocab_sizes, max_splits,
                             lowercase_input);
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<NGramHashParams*>(buffer);
}

TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
  TfLiteTensor* output = GetOutput(context, node, kOutputLabel);
  TF_LITE_ENSURE(context, output != nullptr);
  SetTensorToDynamic(output);
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  NGramHashParams* params = reinterpret_cast<NGramHashParams*>(node->user_data);
  TF_LITE_ENSURE_OK(
      context,
      params->PreprocessInput(GetInput(context, node, kInputMessage), context));

  TfLiteTensor* output = GetOutput(context, node, kOutputLabel);
  TF_LITE_ENSURE(context, output != nullptr);
  if (IsDynamicTensor(output)) {
    TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
    output_size->data[0] = 1;
    output_size->data[1] = params->GetNumNGrams();
    output_size->data[2] = params->GetNumTokens();
    TF_LITE_ENSURE_OK(context,
                      context->ResizeTensor(context, output, output_size));
  } else {
context->ReportError(context, "Output must by dynamic.");
    return kTfLiteError;
  }

  if (output->type == kTfLiteInt32) {
    GetNGramHashIndices(params, output->data.i32);
  } else {
    context->ReportError(context, "Output type must be Int32.");
    return kTfLiteError;
  }

  return kTfLiteOk;
}

}  // namespace ngram_op

TfLiteRegistration* Register_NGRAM_HASH() {
  static TfLiteRegistration r = {ngram_op::Init, ngram_op::Free,
                                 ngram_op::Resize, ngram_op::Eval};
  return &r;
}

}  // namespace tflite::ops::custom
@@ -0,0 +1,27 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_
#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_

#include "tensorflow/lite/kernels/register.h"

namespace tflite::ops::custom {

TfLiteRegistration* Register_NGRAM_HASH();

}  // namespace tflite::ops::custom

#endif  // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_
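The registration entry point declared above can be attached to any TFLite op resolver. A minimal sketch (not part of this commit), assuming the model references the op by the name "NGramHash" used in the tests below:

    #include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"
    #include "tensorflow/lite/kernels/register.h"

    // Make the custom op visible to the interpreter under its model name.
    tflite::ops::builtin::BuiltinOpResolver resolver;
    resolver.AddCustom("NGramHash",
                       tflite::ops::custom::Register_NGRAM_HASH());
    // `resolver` can then be passed to tflite::InterpreterBuilder as usual.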
@@ -0,0 +1,313 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"

#include <cstdint>
#include <optional>
#include <string>
#include <vector>

#include "absl/types/optional.h"
#include "flatbuffers/flexbuffers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/string_util.h"

namespace tflite::ops::custom {
namespace {

using ::flexbuffers::Builder;
using ::mediapipe::tasks::text::language_detector::custom_ops::hash::
    MurmurHash64WithSeed;
using ::testing::ElementsAreArray;
using ::testing::Message;

// Helper class for testing the op.
class NGramHashModel : public SingleOpModel {
 public:
  explicit NGramHashModel(const uint64_t seed,
                          const std::vector<int>& ngram_lengths,
                          const std::vector<int>& vocab_sizes,
                          const absl::optional<int> max_splits = std::nullopt) {
    // Setup the model inputs.
    Builder fbb;
    size_t start = fbb.StartMap();
    fbb.UInt("seed", seed);
    {
      size_t start = fbb.StartVector("ngram_lengths");
      for (const int& ngram_len : ngram_lengths) {
        fbb.Int(ngram_len);
      }
      fbb.EndVector(start, /*typed=*/true, /*fixed=*/false);
    }
    {
      size_t start = fbb.StartVector("vocab_sizes");
      for (const int& vocab_size : vocab_sizes) {
        fbb.Int(vocab_size);
      }
      fbb.EndVector(start, /*typed=*/true, /*fixed=*/false);
    }
    if (max_splits) {
      fbb.Int("max_splits", *max_splits);
    }
    fbb.EndMap(start);
    fbb.Finish();
    output_ = AddOutput({TensorType_INT32, {}});
    SetCustomOp("NGramHash", fbb.GetBuffer(), Register_NGRAM_HASH);
    BuildInterpreter({GetShape(input_)});
  }

  void SetupInputTensor(const std::string& input) {
    PopulateStringTensor(input_, {input});
    CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
        << "Cannot allocate tensors";
  }

  void Invoke(const std::string& input) {
    SetupInputTensor(input);
    CHECK_EQ(SingleOpModel::Invoke(), kTfLiteOk);
  }

  TfLiteStatus InvokeUnchecked(const std::string& input) {
    SetupInputTensor(input);
    return SingleOpModel::Invoke();
  }

  template <typename T>
  std::vector<T> GetOutput() {
    return ExtractVector<T>(output_);
  }

  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

 private:
  int input_ = AddInput(TensorType_STRING);
  int output_;
};

TEST(NGramHashTest, ReturnsExpectedValueWhenInputIsSane) {
  // Checks that the op returns the expected value when the input is sane.
  // Also checks that when `max_splits` is not specified, the entire string is
  // tokenized.
  const uint64_t kSeed = 123;
  const std::vector<int> vocab_sizes({100, 200});
  std::vector<int> ngram_lengths({1, 2});
  const std::vector<std::string> testcase_inputs({
      "hi",
      "wow",
      "!",
      "HI",
  });

  // A hash function that maps the given string to an index in the embedding
  // table denoted by `vocab_idx`.
  auto hash = [vocab_sizes](std::string str, const int vocab_idx) {
    const auto hash_value =
        MurmurHash64WithSeed(str.c_str(), str.size(), kSeed);
    return static_cast<int>((hash_value % vocab_sizes[vocab_idx]) + 1);
  };
  const std::vector<std::vector<int>> expected_testcase_outputs(
      {{
           // Unigram & Bigram output for "hi".
           hash("^", 0),
           hash("h", 0),
           hash("i", 0),
           hash("$", 0),
           hash("^h", 1),
           hash("hi", 1),
           hash("i$", 1),
           hash("$", 1),
       },
       {
           // Unigram & Bigram output for "wow".
           hash("^", 0),
           hash("w", 0),
           hash("o", 0),
           hash("w", 0),
           hash("$", 0),
           hash("^w", 1),
           hash("wo", 1),
           hash("ow", 1),
           hash("w$", 1),
           hash("$", 1),
       },
       {
           // Unigram & Bigram output for "!" (which will get replaced by " ").
           hash("^", 0),
           hash(" ", 0),
           hash("$", 0),
           hash("^ ", 1),
           hash(" $", 1),
           hash("$", 1),
       },
       {
           // Unigram & Bigram output for "HI" (which will get lower-cased).
           hash("^", 0),
           hash("h", 0),
           hash("i", 0),
           hash("$", 0),
           hash("^h", 1),
           hash("hi", 1),
           hash("i$", 1),
           hash("$", 1),
       }});

  NGramHashModel m(kSeed, ngram_lengths, vocab_sizes);
  for (int test_idx = 0; test_idx < testcase_inputs.size(); test_idx++) {
    const string& testcase_input = testcase_inputs[test_idx];
    m.Invoke(testcase_input);
    SCOPED_TRACE(Message() << "Where the testcases' input is: "
                           << testcase_input);
    EXPECT_THAT(m.GetOutput<int>(),
                ElementsAreArray(expected_testcase_outputs[test_idx]));
    EXPECT_THAT(m.GetOutputShape(),
                ElementsAreArray(
                    {/*batch_size=*/1, static_cast<int>(ngram_lengths.size()),
                     static_cast<int>(testcase_input.size()) + /*padding*/ 2}));
  }
}

TEST(NGramHashTest, ReturnsExpectedValueWhenMaxSplitsIsSpecified) {
  // Checks that the op returns the expected value when the input is correct
  // when `max_splits` is specified.
  const uint64_t kSeed = 123;
  const std::vector<int> vocab_sizes({100, 200});
  std::vector<int> ngram_lengths({1, 2});

  const std::string testcase_input = "wow";
  const std::vector<int> max_splits({2, 3, 4, 5, 6});

  // A hash function that maps the given string to an index in the embedding
  // table denoted by `vocab_idx`.
  auto hash = [vocab_sizes](std::string str, const int vocab_idx) {
    const auto hash_value =
        MurmurHash64WithSeed(str.c_str(), str.size(), kSeed);
    return static_cast<int>((hash_value % vocab_sizes[vocab_idx]) + 1);
  };

  const std::vector<std::vector<int>> expected_testcase_outputs(
      {{
           // Unigram & Bigram output for "wow", when `max_splits` == 2.
           // We cannot include any of the actual tokens, since `max_splits`
           // only allows enough space for the delimiters.
           hash("^", 0),
           hash("$", 0),
           hash("^$", 1),
           hash("$", 1),
       },
       {
           // Unigram & Bigram output for "wow", when `max_splits` == 3.
           // We can start to include some tokens from the input string.
           hash("^", 0),
           hash("w", 0),
           hash("$", 0),
           hash("^w", 1),
           hash("w$", 1),
           hash("$", 1),
       },
       {
           // Unigram & Bigram output for "wow", when `max_splits` == 4.
           hash("^", 0),
           hash("w", 0),
           hash("o", 0),
           hash("$", 0),
           hash("^w", 1),
           hash("wo", 1),
           hash("o$", 1),
           hash("$", 1),
       },
       {
           // Unigram & Bigram output for "wow", when `max_splits` == 5.
           // We can include the full input string.
           hash("^", 0),
           hash("w", 0),
           hash("o", 0),
           hash("w", 0),
           hash("$", 0),
           hash("^w", 1),
           hash("wo", 1),
           hash("ow", 1),
           hash("w$", 1),
           hash("$", 1),
       },
       {
           // Unigram & Bigram output for "wow", when `max_splits` == 6.
           // `max_splits` is more than the full input string.
           hash("^", 0),
           hash("w", 0),
           hash("o", 0),
           hash("w", 0),
           hash("$", 0),
           hash("^w", 1),
           hash("wo", 1),
           hash("ow", 1),
           hash("w$", 1),
           hash("$", 1),
       }});

  for (int test_idx = 0; test_idx < max_splits.size(); test_idx++) {
    const int testcase_max_splits = max_splits[test_idx];
    NGramHashModel m(kSeed, ngram_lengths, vocab_sizes, testcase_max_splits);
    m.Invoke(testcase_input);
    SCOPED_TRACE(Message() << "Where `max_splits` is: " << testcase_max_splits);
    EXPECT_THAT(m.GetOutput<int>(),
                ElementsAreArray(expected_testcase_outputs[test_idx]));
    EXPECT_THAT(
        m.GetOutputShape(),
        ElementsAreArray(
            {/*batch_size=*/1, static_cast<int>(ngram_lengths.size()),
             std::min(
                 // Longest possible tokenization when using the entire
                 // input.
                 static_cast<int>(testcase_input.size()) + /*padding*/ 2,
                 // Longest possible string when the `max_splits` value
                 // is < testcase_input.size() + 2 for padding.
                 testcase_max_splits)}));
  }
}

TEST(NGramHashTest, InvalidMaxSplitsValue) {
  // Check that the op errors out when given an invalid max splits value.
  const std::vector<int> invalid_max_splits({0, -1, -5, -100});
  for (const int max_splits : invalid_max_splits) {
    NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200},
                     /*vocab_sizes=*/{1, 2}, /*max_splits=*/max_splits);
    EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError);
  }
}

TEST(NGramHashTest, MismatchNgramLengthsAndVocabSizes) {
  // Check that the op errors out when ngram lengths and vocab sizes mismatch.
  {
    NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200, 300},
                     /*vocab_sizes=*/{1, 2});
    EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError);
  }
  {
    NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200},
                     /*vocab_sizes=*/{1, 2, 3});
    EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError);
  }
}

}  // namespace
}  // namespace tflite::ops::custom
@@ -0,0 +1,42 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

cc_library(
    name = "ngram_hash_ops_utils",
    srcs = [
        "ngram_hash_ops_utils.cc",
    ],
    hdrs = [
        "ngram_hash_ops_utils.h",
    ],
    deps = [
        "//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf",
    ],
)

cc_test(
    name = "ngram_hash_ops_utils_test",
    size = "small",
    srcs = [
        "ngram_hash_ops_utils_test.cc",
    ],
    deps = [
        ":ngram_hash_ops_utils",
        "//mediapipe/framework/port:gtest_main",
    ],
)
@@ -0,0 +1,96 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h"

#include <string>
#include <utility>
#include <vector>

#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h"

namespace mediapipe::tasks::text::language_detector::custom_ops {

TokenizedOutput Tokenize(const char* input_str, int len, int max_tokens,
                         bool exclude_nonalphaspace_tokens) {
  const std::string kPrefix = "^";
  const std::string kSuffix = "$";
  const std::string kReplacementToken = " ";

  TokenizedOutput output;

  size_t token_start = 0;
  output.str.reserve(len + 2);
  output.tokens.reserve(len + 2);

  output.str.append(kPrefix);
  output.tokens.push_back(std::make_pair(token_start, kPrefix.size()));
  token_start += kPrefix.size();

  Rune token;
  for (int i = 0; i < len && output.tokens.size() + 1 < max_tokens;) {
    // Use the standard UTF-8 library to find the next token.
    size_t bytes_read = utf_charntorune(&token, input_str + i, len - i);

    // Stop processing, if we can't read any more tokens, or we have reached
    // maximum allowed tokens, allocating one token for the suffix.
    if (bytes_read == 0) {
      break;
    }
    // If `exclude_nonalphaspace_tokens` is set to true, and the token is not
    // alphabetic, replace it with a replacement token.
    if (exclude_nonalphaspace_tokens && !utf_isalpharune(token)) {
      output.str.append(kReplacementToken);
      output.tokens.push_back(
          std::make_pair(token_start, kReplacementToken.size()));
      token_start += kReplacementToken.size();
      i += bytes_read;
      continue;
    }

    // Append the token in the output string, and note its position and the
    // number of bytes that token consumed.
    output.str.append(input_str + i, bytes_read);
    output.tokens.push_back(std::make_pair(token_start, bytes_read));
    token_start += bytes_read;
    i += bytes_read;
  }
  output.str.append(kSuffix);
  output.tokens.push_back(std::make_pair(token_start, kSuffix.size()));
  token_start += kSuffix.size();

  return output;
}

void LowercaseUnicodeStr(const char* input_str, int len,
                         std::string* output_str) {
  for (int i = 0; i < len;) {
    Rune token;

    // Tokenize the given string, and get the appropriate lowercase token.
    size_t bytes_read = utf_charntorune(&token, input_str + i, len - i);
    token = utf_isalpharune(token) ? utf_tolowerrune(token) : token;

    // Write back the token to the output string.
    char token_buf[UTFmax];
    size_t bytes_to_write = utf_runetochar(token_buf, &token);
    output_str->append(token_buf, bytes_to_write);

    i += bytes_read;
  }
}

}  // namespace mediapipe::tasks::text::language_detector::custom_ops
@@ -0,0 +1,56 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_NGRAM_HASH_OPS_UTILS_H_
#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_NGRAM_HASH_OPS_UTILS_H_

#include <string>
#include <utility>
#include <vector>

namespace mediapipe::tasks::text::language_detector::custom_ops {

struct TokenizedOutput {
  // The processed string (with necessary prefix, suffix, skipped tokens, etc.).
  std::string str;

  // This vector contains pairs, where each pair has two members. The first
  // denoting the starting index of the token in the `str` string, and the
  // second denoting the length of that token in bytes.
  std::vector<std::pair<const size_t, const size_t>> tokens;
};

// Tokenizes the given input string on Unicode token boundaries, with a maximum
// of `max_tokens` tokens.
//
// If `exclude_nonalphaspace_tokens` is enabled, the tokenization ignores
// non-alphabetic tokens and replaces them with a replacement token (" ").
//
// The method returns the output in the `TokenizedOutput` struct, which stores
// both, the processed input string, and the indices and sizes of each token
// within that string.
TokenizedOutput Tokenize(const char* input_str, int len, int max_tokens,
                         bool exclude_nonalphaspace_tokens);

// Converts the given unicode string (`input_str`) with the specified length
// (`len`) to a lowercase string.
//
// The method populates the lowercased string in `output_str`.
void LowercaseUnicodeStr(const char* input_str, int len,
                         std::string* output_str);

}  // namespace mediapipe::tasks::text::language_detector::custom_ops

#endif  // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_NGRAM_HASH_OPS_UTILS_H_
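As a concrete illustration of the contract above (values mirror the unit tests in the next file), tokenizing "hi!" with replacement enabled yields the padded string "^hi $" and one (offset, byte-length) pair per token:

    using ::mediapipe::tasks::text::language_detector::custom_ops::Tokenize;
    using ::mediapipe::tasks::text::language_detector::custom_ops::TokenizedOutput;

    TokenizedOutput out = Tokenize("hi!", /*len=*/3, /*max_tokens=*/100,
                                   /*exclude_nonalphaspace_tokens=*/true);
    // out.str    == "^hi $"   ("!" is non-alphabetic, so it becomes " ")
    // out.tokens == {{0, 1}, {1, 1}, {2, 1}, {3, 1}, {4, 1}}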
@@ -0,0 +1,135 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h"

#include <string>

#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"

namespace mediapipe::tasks::text::language_detector::custom_ops {

namespace {

using ::testing::Values;

std::string ReconstructStringFromTokens(TokenizedOutput output) {
  std::string reconstructed_str;
  for (int i = 0; i < output.tokens.size(); i++) {
    reconstructed_str.append(
        output.str.c_str() + output.tokens[i].first,
        output.str.c_str() + output.tokens[i].first + output.tokens[i].second);
  }
  return reconstructed_str;
}

struct TokenizeTestParams {
  std::string input_str;
  size_t max_tokens;
  bool exclude_nonalphaspace_tokens;
  std::string expected_output_str;
};

class TokenizeParameterizedTest
    : public ::testing::Test,
      public testing::WithParamInterface<TokenizeTestParams> {};

TEST_P(TokenizeParameterizedTest, Tokenize) {
  // Checks that the Tokenize method returns the expected value.
  const TokenizeTestParams params = TokenizeParameterizedTest::GetParam();
  const TokenizedOutput output = Tokenize(
      /*input_str=*/params.input_str.c_str(),
      /*len=*/params.input_str.size(),
      /*max_tokens=*/params.max_tokens,
      /*exclude_nonalphaspace_tokens=*/params.exclude_nonalphaspace_tokens);

  // The output string should have the necessary prefixes, and the "!" token
  // should have been replaced with a " ".
  EXPECT_EQ(output.str, params.expected_output_str);
  EXPECT_EQ(ReconstructStringFromTokens(output), params.expected_output_str);
}

INSTANTIATE_TEST_SUITE_P(
    TokenizeParameterizedTests, TokenizeParameterizedTest,
    Values(
        // Test including non-alphanumeric characters.
        TokenizeTestParams({/*input_str=*/"hi!", /*max_tokens=*/100,
                            /*exclude_alphanonspace=*/false,
                            /*expected_output_str=*/"^hi!$"}),
        // Test not including non-alphanumeric characters.
        TokenizeTestParams({/*input_str=*/"hi!", /*max_tokens=*/100,
                            /*exclude_alphanonspace=*/true,
                            /*expected_output_str=*/"^hi $"}),
        // Test with a maximum of 3 tokens.
        TokenizeTestParams({/*input_str=*/"hi!", /*max_tokens=*/3,
                            /*exclude_alphanonspace=*/true,
                            /*expected_output_str=*/"^h$"}),
        // Test with non-latin characters.
        TokenizeTestParams({/*input_str=*/"ありがと", /*max_tokens=*/100,
                            /*exclude_alphanonspace=*/true,
                            /*expected_output_str=*/"^ありがと$"})));

TEST(LowercaseUnicodeTest, TestLowercaseUnicode) {
  {
    // Check that the method is a no-op when the string is lowercase.
    std::string input_str = "hello";
    std::string output_str;
    LowercaseUnicodeStr(
        /*input_str=*/input_str.c_str(),
        /*len=*/input_str.size(),
        /*output_str=*/&output_str);

    EXPECT_EQ(output_str, "hello");
  }
  {
    // Check that the method lowercases a string with uppercase characters.
    std::string input_str = "hElLo";
    std::string output_str;
    LowercaseUnicodeStr(
        /*input_str=*/input_str.c_str(),
        /*len=*/input_str.size(),
        /*output_str=*/&output_str);

    EXPECT_EQ(output_str, "hello");
  }
  {
    // Check that the method works with non-latin scripts.
    // Cyrillic has the concept of cases, so it should change the input.
    std::string input_str = "БЙп";
    std::string output_str;
    LowercaseUnicodeStr(
        /*input_str=*/input_str.c_str(),
        /*len=*/input_str.size(),
        /*output_str=*/&output_str);

    EXPECT_EQ(output_str, "бйп");
  }
  {
    // Check that the method works with non-latin scripts.
    // Japanese doesn't have the concept of cases, so it should not change.
    std::string input_str = "ありがと";
    std::string output_str;
    LowercaseUnicodeStr(
        /*input_str=*/input_str.c_str(),
        /*len=*/input_str.size(),
        /*output_str=*/&output_str);

    EXPECT_EQ(output_str, "ありがと");
  }
}

}  // namespace
}  // namespace mediapipe::tasks::text::language_detector::custom_ops
@@ -0,0 +1,27 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

cc_library(
    name = "utf",
    srcs = [
        "rune.c",
        "runetype.c",
        "runetypebody.h",
    ],
    hdrs = ["utf.h"],
)
@@ -0,0 +1,233 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Forked from a library written by Rob Pike and Ken Thompson. Original
// copyright message below.
/*
 * The authors of this software are Rob Pike and Ken Thompson.
 * Copyright (c) 2002 by Lucent Technologies.
 * Permission to use, copy, modify, and distribute this software for any
 * purpose without fee is hereby granted, provided that this entire notice
 * is included in all copies of any software which is or includes a copy
 * or modification of this software and in all copies of the supporting
 * documentation for such software.
 * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED
 * WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY
 * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY
 * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE.
 */
#include <stdarg.h>
#include <string.h>
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h"

enum
{
  Bit1 = 7,
  Bitx = 6,
  Bit2 = 5,
  Bit3 = 4,
  Bit4 = 3,
  Bit5 = 2,

  T1 = ((1<<(Bit1+1))-1) ^ 0xFF, /* 0000 0000 */
  Tx = ((1<<(Bitx+1))-1) ^ 0xFF, /* 1000 0000 */
  T2 = ((1<<(Bit2+1))-1) ^ 0xFF, /* 1100 0000 */
  T3 = ((1<<(Bit3+1))-1) ^ 0xFF, /* 1110 0000 */
  T4 = ((1<<(Bit4+1))-1) ^ 0xFF, /* 1111 0000 */
  T5 = ((1<<(Bit5+1))-1) ^ 0xFF, /* 1111 1000 */

  Rune1 = (1<<(Bit1+0*Bitx))-1, /* 0000 0000 0111 1111 */
  Rune2 = (1<<(Bit2+1*Bitx))-1, /* 0000 0111 1111 1111 */
  Rune3 = (1<<(Bit3+2*Bitx))-1, /* 1111 1111 1111 1111 */
  Rune4 = (1<<(Bit4+3*Bitx))-1, /* 0001 1111 1111 1111 1111 1111 */

  Maskx = (1<<Bitx)-1,  /* 0011 1111 */
  Testx = Maskx ^ 0xFF, /* 1100 0000 */

  Bad = Runeerror,
};

/*
 * Modified by Wei-Hwa Huang, Google Inc., on 2004-09-24
 * This is a slower but "safe" version of the old chartorune
 * that works on strings that are not necessarily null-terminated.
 *
 * If you know for sure that your string is null-terminated,
 * chartorune will be a bit faster.
 *
 * It is guaranteed not to attempt to access "length"
 * past the incoming pointer. This is to avoid
 * possible access violations. If the string appears to be
 * well-formed but incomplete (i.e., to get the whole Rune
 * we'd need to read past str+length) then we'll set the Rune
 * to Bad and return 0.
 *
 * Note that if we have decoding problems for other
 * reasons, we return 1 instead of 0.
 */
int
utf_charntorune(Rune *rune, const char *str, int length)
{
  int c, c1, c2, c3;
  long l;

  /* When we're not allowed to read anything */
  if(length <= 0) {
    goto badlen;
  }

  /*
   * one character sequence (7-bit value)
   * 00000-0007F => T1
   */
  c = *(uchar*)str;
  if(c < Tx) {
    *rune = c;
    return 1;
  }

  // If we can't read more than one character we must stop
  if(length <= 1) {
    goto badlen;
  }

  /*
   * two character sequence (11-bit value)
   * 0080-07FF => T2 Tx
   */
  c1 = *(uchar*)(str+1) ^ Tx;
  if(c1 & Testx)
    goto bad;
  if(c < T3) {
    if(c < T2)
      goto bad;
    l = ((c << Bitx) | c1) & Rune2;
    if(l <= Rune1)
      goto bad;
    *rune = l;
    return 2;
  }

  // If we can't read more than two characters we must stop
  if(length <= 2) {
    goto badlen;
  }

  /*
   * three character sequence (16-bit value)
   * 0800-FFFF => T3 Tx Tx
   */
  c2 = *(uchar*)(str+2) ^ Tx;
  if(c2 & Testx)
    goto bad;
  if(c < T4) {
    l = ((((c << Bitx) | c1) << Bitx) | c2) & Rune3;
    if(l <= Rune2)
      goto bad;
    *rune = l;
    return 3;
  }

  if (length <= 3)
    goto badlen;

  /*
   * four character sequence (21-bit value)
   * 10000-1FFFFF => T4 Tx Tx Tx
   */
  c3 = *(uchar*)(str+3) ^ Tx;
  if (c3 & Testx)
    goto bad;
  if (c < T5) {
    l = ((((((c << Bitx) | c1) << Bitx) | c2) << Bitx) | c3) & Rune4;
    if (l <= Rune3)
      goto bad;
    if (l > Runemax)
      goto bad;
    *rune = l;
    return 4;
  }

  // Support for 5-byte or longer UTF-8 would go here, but
  // since we don't have that, we'll just fall through to bad.

  /*
   * bad decoding
   */
bad:
  *rune = Bad;
  return 1;
badlen:
  *rune = Bad;
  return 0;
}

int
utf_runetochar(char *str, const Rune *rune)
{
  /* Runes are signed, so convert to unsigned for range check. */
  unsigned long c;

  /*
   * one character sequence
   * 00000-0007F => 00-7F
   */
  c = *rune;
  if(c <= Rune1) {
    str[0] = c;
    return 1;
  }

  /*
   * two character sequence
   * 0080-07FF => T2 Tx
   */
  if(c <= Rune2) {
    str[0] = T2 | (c >> 1*Bitx);
    str[1] = Tx | (c & Maskx);
    return 2;
  }

  /*
   * If the Rune is out of range, convert it to the error rune.
   * Do this test here because the error rune encodes to three bytes.
   * Doing it earlier would duplicate work, since an out of range
   * Rune wouldn't have fit in one or two bytes.
   */
  if (c > Runemax)
    c = Runeerror;

  /*
   * three character sequence
   * 0800-FFFF => T3 Tx Tx
   */
  if (c <= Rune3) {
    str[0] = T3 | (c >> 2*Bitx);
    str[1] = Tx | ((c >> 1*Bitx) & Maskx);
    str[2] = Tx | (c & Maskx);
    return 3;
  }

  /*
   * four character sequence (21-bit value)
   * 10000-1FFFFF => T4 Tx Tx Tx
   */
  str[0] = T4 | (c >> 3*Bitx);
  str[1] = Tx | ((c >> 2*Bitx) & Maskx);
  str[2] = Tx | ((c >> 1*Bitx) & Maskx);
  str[3] = Tx | (c & Maskx);
  return 4;
}
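A small usage sketch for the two entry points above (illustrative only; "あ" occupies three bytes in UTF-8):

    Rune r;
    // Decode at most 3 bytes: returns the byte count consumed (3 here),
    // 1 for a malformed sequence, or 0 if the buffer ends mid-rune.
    int consumed = utf_charntorune(&r, "あ", 3);

    char buf[UTFmax];
    // Re-encode the rune; writes the same three bytes back into `buf`.
    int written = utf_runetochar(buf, &r);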
@@ -0,0 +1,54 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Forked from a library written by Rob Pike and Ken Thompson. Original
// copyright message below.
/*
 * The authors of this software are Rob Pike and Ken Thompson.
 * Copyright (c) 2002 by Lucent Technologies.
 * Permission to use, copy, modify, and distribute this software for any
 * purpose without fee is hereby granted, provided that this entire notice
 * is included in all copies of any software which is or includes a copy
 * or modification of this software and in all copies of the supporting
 * documentation for such software.
 * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED
 * WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY
 * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY
 * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE.
 */
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h"

static
Rune*
rbsearch(Rune c, Rune *t, int n, int ne)
{
  Rune *p;
  int m;

  while(n > 1) {
    m = n >> 1;
    p = t + m*ne;
    if(c >= p[0]) {
      t = p;
      n = n-m;
    } else
      n = m;
  }
  if(n && c >= t[0])
    return t;
  return 0;
}

#define RUNETYPEBODY
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/runetypebody.h"
@ -0,0 +1,212 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifdef RUNETYPEBODY
|
||||
|
||||
static Rune __isalphar[] = {
|
||||
0x0041, 0x005a, 0x0061, 0x007a, 0x00c0, 0x00d6, 0x00d8, 0x00f6,
|
||||
0x00f8, 0x02c1, 0x02c6, 0x02d1, 0x02e0, 0x02e4, 0x0370, 0x0374,
|
||||
0x0376, 0x0377, 0x037a, 0x037d, 0x0388, 0x038a, 0x038e, 0x03a1,
|
||||
0x03a3, 0x03f5, 0x03f7, 0x0481, 0x048a, 0x0527, 0x0531, 0x0556,
|
||||
0x0561, 0x0587, 0x05d0, 0x05ea, 0x05f0, 0x05f2, 0x0620, 0x064a,
|
||||
0x066e, 0x066f, 0x0671, 0x06d3, 0x06e5, 0x06e6, 0x06ee, 0x06ef,
|
||||
0x06fa, 0x06fc, 0x0712, 0x072f, 0x074d, 0x07a5, 0x07ca, 0x07ea,
|
||||
0x07f4, 0x07f5, 0x0800, 0x0815, 0x0840, 0x0858, 0x08a2, 0x08ac,
|
||||
0x0904, 0x0939, 0x0958, 0x0961, 0x0971, 0x0977, 0x0979, 0x097f,
|
||||
0x0985, 0x098c, 0x098f, 0x0990, 0x0993, 0x09a8, 0x09aa, 0x09b0,
|
||||
0x09b6, 0x09b9, 0x09dc, 0x09dd, 0x09df, 0x09e1, 0x09f0, 0x09f1,
|
||||
0x0a05, 0x0a0a, 0x0a0f, 0x0a10, 0x0a13, 0x0a28, 0x0a2a, 0x0a30,
|
||||
0x0a32, 0x0a33, 0x0a35, 0x0a36, 0x0a38, 0x0a39, 0x0a59, 0x0a5c,
|
||||
0x0a72, 0x0a74, 0x0a85, 0x0a8d, 0x0a8f, 0x0a91, 0x0a93, 0x0aa8,
|
||||
0x0aaa, 0x0ab0, 0x0ab2, 0x0ab3, 0x0ab5, 0x0ab9, 0x0ae0, 0x0ae1,
|
||||
0x0b05, 0x0b0c, 0x0b0f, 0x0b10, 0x0b13, 0x0b28, 0x0b2a, 0x0b30,
|
||||
0x0b32, 0x0b33, 0x0b35, 0x0b39, 0x0b5c, 0x0b5d, 0x0b5f, 0x0b61,
|
||||
0x0b85, 0x0b8a, 0x0b8e, 0x0b90, 0x0b92, 0x0b95, 0x0b99, 0x0b9a,
|
||||
0x0b9e, 0x0b9f, 0x0ba3, 0x0ba4, 0x0ba8, 0x0baa, 0x0bae, 0x0bb9,
|
||||
0x0c05, 0x0c0c, 0x0c0e, 0x0c10, 0x0c12, 0x0c28, 0x0c2a, 0x0c33,
|
||||
0x0c35, 0x0c39, 0x0c58, 0x0c59, 0x0c60, 0x0c61, 0x0c85, 0x0c8c,
|
||||
0x0c8e, 0x0c90, 0x0c92, 0x0ca8, 0x0caa, 0x0cb3, 0x0cb5, 0x0cb9,
|
||||
0x0ce0, 0x0ce1, 0x0cf1, 0x0cf2, 0x0d05, 0x0d0c, 0x0d0e, 0x0d10,
|
||||
0x0d12, 0x0d3a, 0x0d60, 0x0d61, 0x0d7a, 0x0d7f, 0x0d85, 0x0d96,
|
||||
0x0d9a, 0x0db1, 0x0db3, 0x0dbb, 0x0dc0, 0x0dc6, 0x0e01, 0x0e30,
|
||||
0x0e32, 0x0e33, 0x0e40, 0x0e46, 0x0e81, 0x0e82, 0x0e87, 0x0e88,
|
||||
0x0e94, 0x0e97, 0x0e99, 0x0e9f, 0x0ea1, 0x0ea3, 0x0eaa, 0x0eab,
|
||||
0x0ead, 0x0eb0, 0x0eb2, 0x0eb3, 0x0ec0, 0x0ec4, 0x0edc, 0x0edf,
|
||||
0x0f40, 0x0f47, 0x0f49, 0x0f6c, 0x0f88, 0x0f8c, 0x1000, 0x102a,
|
||||
0x1050, 0x1055, 0x105a, 0x105d, 0x1065, 0x1066, 0x106e, 0x1070,
|
||||
0x1075, 0x1081, 0x10a0, 0x10c5, 0x10d0, 0x10fa, 0x10fc, 0x1248,
|
||||
0x124a, 0x124d, 0x1250, 0x1256, 0x125a, 0x125d, 0x1260, 0x1288,
|
||||
0x128a, 0x128d, 0x1290, 0x12b0, 0x12b2, 0x12b5, 0x12b8, 0x12be,
|
||||
0x12c2, 0x12c5, 0x12c8, 0x12d6, 0x12d8, 0x1310, 0x1312, 0x1315,
|
||||
0x1318, 0x135a, 0x1380, 0x138f, 0x13a0, 0x13f4, 0x1401, 0x166c,
|
||||
0x166f, 0x167f, 0x1681, 0x169a, 0x16a0, 0x16ea, 0x1700, 0x170c,
|
||||
0x170e, 0x1711, 0x1720, 0x1731, 0x1740, 0x1751, 0x1760, 0x176c,
|
||||
0x176e, 0x1770, 0x1780, 0x17b3, 0x1820, 0x1877, 0x1880, 0x18a8,
|
||||
0x18b0, 0x18f5, 0x1900, 0x191c, 0x1950, 0x196d, 0x1970, 0x1974,
|
||||
0x1980, 0x19ab, 0x19c1, 0x19c7, 0x1a00, 0x1a16, 0x1a20, 0x1a54,
|
||||
0x1b05, 0x1b33, 0x1b45, 0x1b4b, 0x1b83, 0x1ba0, 0x1bae, 0x1baf,
|
||||
0x1bba, 0x1be5, 0x1c00, 0x1c23, 0x1c4d, 0x1c4f, 0x1c5a, 0x1c7d,
|
||||
0x1ce9, 0x1cec, 0x1cee, 0x1cf1, 0x1cf5, 0x1cf6, 0x1d00, 0x1dbf,
    0x1e00, 0x1f15, 0x1f18, 0x1f1d, 0x1f20, 0x1f45, 0x1f48, 0x1f4d,
    0x1f50, 0x1f57, 0x1f5f, 0x1f7d, 0x1f80, 0x1fb4, 0x1fb6, 0x1fbc,
    0x1fc2, 0x1fc4, 0x1fc6, 0x1fcc, 0x1fd0, 0x1fd3, 0x1fd6, 0x1fdb,
    0x1fe0, 0x1fec, 0x1ff2, 0x1ff4, 0x1ff6, 0x1ffc, 0x2090, 0x209c,
    0x210a, 0x2113, 0x2119, 0x211d, 0x212a, 0x212d, 0x212f, 0x2139,
    0x213c, 0x213f, 0x2145, 0x2149, 0x2183, 0x2184, 0x2c00, 0x2c2e,
    0x2c30, 0x2c5e, 0x2c60, 0x2ce4, 0x2ceb, 0x2cee, 0x2cf2, 0x2cf3,
    0x2d00, 0x2d25, 0x2d30, 0x2d67, 0x2d80, 0x2d96, 0x2da0, 0x2da6,
    0x2da8, 0x2dae, 0x2db0, 0x2db6, 0x2db8, 0x2dbe, 0x2dc0, 0x2dc6,
    0x2dc8, 0x2dce, 0x2dd0, 0x2dd6, 0x2dd8, 0x2dde, 0x3005, 0x3006,
    0x3031, 0x3035, 0x303b, 0x303c, 0x3041, 0x3096, 0x309d, 0x309f,
    0x30a1, 0x30fa, 0x30fc, 0x30ff, 0x3105, 0x312d, 0x3131, 0x318e,
    0x31a0, 0x31ba, 0x31f0, 0x31ff, 0x3400, 0x4db5, 0x4e00, 0x9fcc,
    0xa000, 0xa48c, 0xa4d0, 0xa4fd, 0xa500, 0xa60c, 0xa610, 0xa61f,
    0xa62a, 0xa62b, 0xa640, 0xa66e, 0xa67f, 0xa697, 0xa6a0, 0xa6e5,
    0xa717, 0xa71f, 0xa722, 0xa788, 0xa78b, 0xa78e, 0xa790, 0xa793,
    0xa7a0, 0xa7aa, 0xa7f8, 0xa801, 0xa803, 0xa805, 0xa807, 0xa80a,
    0xa80c, 0xa822, 0xa840, 0xa873, 0xa882, 0xa8b3, 0xa8f2, 0xa8f7,
    0xa90a, 0xa925, 0xa930, 0xa946, 0xa960, 0xa97c, 0xa984, 0xa9b2,
    0xaa00, 0xaa28, 0xaa40, 0xaa42, 0xaa44, 0xaa4b, 0xaa60, 0xaa76,
    0xaa80, 0xaaaf, 0xaab5, 0xaab6, 0xaab9, 0xaabd, 0xaadb, 0xaadd,
    0xaae0, 0xaaea, 0xaaf2, 0xaaf4, 0xab01, 0xab06, 0xab09, 0xab0e,
    0xab11, 0xab16, 0xab20, 0xab26, 0xab28, 0xab2e, 0xabc0, 0xabe2,
    0xac00, 0xd7a3, 0xd7b0, 0xd7c6, 0xd7cb, 0xd7fb, 0xf900, 0xfa6d,
    0xfa70, 0xfad9, 0xfb00, 0xfb06, 0xfb13, 0xfb17, 0xfb1f, 0xfb28,
    0xfb2a, 0xfb36, 0xfb38, 0xfb3c, 0xfb40, 0xfb41, 0xfb43, 0xfb44,
    0xfb46, 0xfbb1, 0xfbd3, 0xfd3d, 0xfd50, 0xfd8f, 0xfd92, 0xfdc7,
    0xfdf0, 0xfdfb, 0xfe70, 0xfe74, 0xfe76, 0xfefc, 0xff21, 0xff3a,
    0xff41, 0xff5a, 0xff66, 0xffbe, 0xffc2, 0xffc7, 0xffca, 0xffcf,
    0xffd2, 0xffd7, 0xffda, 0xffdc, 0x10000, 0x1000b, 0x1000d, 0x10026,
    0x10028, 0x1003a, 0x1003c, 0x1003d, 0x1003f, 0x1004d, 0x10050, 0x1005d,
    0x10080, 0x100fa, 0x10280, 0x1029c, 0x102a0, 0x102d0, 0x10300, 0x1031e,
    0x10330, 0x10340, 0x10342, 0x10349, 0x10380, 0x1039d, 0x103a0, 0x103c3,
    0x103c8, 0x103cf, 0x10400, 0x1049d, 0x10800, 0x10805, 0x1080a, 0x10835,
    0x10837, 0x10838, 0x1083f, 0x10855, 0x10900, 0x10915, 0x10920, 0x10939,
    0x10980, 0x109b7, 0x109be, 0x109bf, 0x10a10, 0x10a13, 0x10a15, 0x10a17,
    0x10a19, 0x10a33, 0x10a60, 0x10a7c, 0x10b00, 0x10b35, 0x10b40, 0x10b55,
    0x10b60, 0x10b72, 0x10c00, 0x10c48, 0x11003, 0x11037, 0x11083, 0x110af,
    0x110d0, 0x110e8, 0x11103, 0x11126, 0x11183, 0x111b2, 0x111c1, 0x111c4,
    0x11680, 0x116aa, 0x12000, 0x1236e, 0x13000, 0x1342e, 0x16800, 0x16a38,
    0x16f00, 0x16f44, 0x16f93, 0x16f9f, 0x1b000, 0x1b001, 0x1d400, 0x1d454,
    0x1d456, 0x1d49c, 0x1d49e, 0x1d49f, 0x1d4a5, 0x1d4a6, 0x1d4a9, 0x1d4ac,
    0x1d4ae, 0x1d4b9, 0x1d4bd, 0x1d4c3, 0x1d4c5, 0x1d505, 0x1d507, 0x1d50a,
    0x1d50d, 0x1d514, 0x1d516, 0x1d51c, 0x1d51e, 0x1d539, 0x1d53b, 0x1d53e,
    0x1d540, 0x1d544, 0x1d54a, 0x1d550, 0x1d552, 0x1d6a5, 0x1d6a8, 0x1d6c0,
    0x1d6c2, 0x1d6da, 0x1d6dc, 0x1d6fa, 0x1d6fc, 0x1d714, 0x1d716, 0x1d734,
    0x1d736, 0x1d74e, 0x1d750, 0x1d76e, 0x1d770, 0x1d788, 0x1d78a, 0x1d7a8,
    0x1d7aa, 0x1d7c2, 0x1d7c4, 0x1d7cb, 0x1ee00, 0x1ee03, 0x1ee05, 0x1ee1f,
    0x1ee21, 0x1ee22, 0x1ee29, 0x1ee32, 0x1ee34, 0x1ee37, 0x1ee4d, 0x1ee4f,
    0x1ee51, 0x1ee52, 0x1ee61, 0x1ee62, 0x1ee67, 0x1ee6a, 0x1ee6c, 0x1ee72,
    0x1ee74, 0x1ee77, 0x1ee79, 0x1ee7c, 0x1ee80, 0x1ee89, 0x1ee8b, 0x1ee9b,
    0x1eea1, 0x1eea3, 0x1eea5, 0x1eea9, 0x1eeab, 0x1eebb, 0x20000, 0x2a6d6,
    0x2a700, 0x2b734, 0x2b740, 0x2b81d, 0x2f800, 0x2fa1d,
};

static Rune __isalphas[] = {
    0x00aa, 0x00b5, 0x00ba, 0x02ec, 0x02ee, 0x0386, 0x038c, 0x0559,
    0x06d5, 0x06ff, 0x0710, 0x07b1, 0x07fa, 0x081a, 0x0824, 0x0828,
    0x08a0, 0x093d, 0x0950, 0x09b2, 0x09bd, 0x09ce, 0x0a5e, 0x0abd,
    0x0ad0, 0x0b3d, 0x0b71, 0x0b83, 0x0b9c, 0x0bd0, 0x0c3d, 0x0cbd,
    0x0cde, 0x0d3d, 0x0d4e, 0x0dbd, 0x0e84, 0x0e8a, 0x0e8d, 0x0ea5,
    0x0ea7, 0x0ebd, 0x0ec6, 0x0f00, 0x103f, 0x1061, 0x108e, 0x10c7,
    0x10cd, 0x1258, 0x12c0, 0x17d7, 0x17dc, 0x18aa, 0x1aa7, 0x1f59,
    0x1f5b, 0x1f5d, 0x1fbe, 0x2071, 0x207f, 0x2102, 0x2107, 0x2115,
    0x2124, 0x2126, 0x2128, 0x214e, 0x2d27, 0x2d2d, 0x2d6f, 0x2e2f,
    0xa8fb, 0xa9cf, 0xaa7a, 0xaab1, 0xaac0, 0xaac2, 0xfb1d, 0xfb3e,
    0x10808, 0x1083c, 0x10a00, 0x16f50, 0x1d4a2, 0x1d4bb, 0x1d546, 0x1ee24,
    0x1ee27, 0x1ee39, 0x1ee3b, 0x1ee42, 0x1ee47, 0x1ee49, 0x1ee4b, 0x1ee54,
    0x1ee57, 0x1ee59, 0x1ee5b, 0x1ee5d, 0x1ee5f, 0x1ee64, 0x1ee7e,
};

int utf_isalpharune(Rune c) {
  Rune *p;

  p = rbsearch(c, __isalphar, nelem(__isalphar) / 2, 2);
  if (p && c >= p[0] && c <= p[1]) return 1;
  p = rbsearch(c, __isalphas, nelem(__isalphas), 1);
  if (p && c == p[0]) return 1;
  return 0;
}
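
An illustration, not part of the change: `__isalphar` holds inclusive [lo, hi] codepoint ranges searched with stride 2, while `__isalphas` holds single codepoints searched with stride 1, so a membership check is at most two binary searches. A minimal sketch, assuming the declarations above are linked in:

#include <assert.h>

void isalpharune_sketch(void) {
  assert(utf_isalpharune(0x00b5));   /* MICRO SIGN, a singleton entry above */
  assert(utf_isalpharune(0x3041));   /* HIRAGANA SMALL A, inside a range */
  assert(!utf_isalpharune(0x0030));  /* ASCII '0' is a digit, not a letter */
}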

static Rune __tolowerr[] = {
    0x0041, 0x005a, 1048608, 0x00c0, 0x00d6, 1048608, 0x00d8, 0x00de, 1048608,
    0x0189, 0x018a, 1048781, 0x01b1, 0x01b2, 1048793, 0x0388, 0x038a, 1048613,
    0x038e, 0x038f, 1048639, 0x0391, 0x03a1, 1048608, 0x03a3, 0x03ab, 1048608,
    0x03fd, 0x03ff, 1048446, 0x0400, 0x040f, 1048656, 0x0410, 0x042f, 1048608,
    0x0531, 0x0556, 1048624, 0x10a0, 0x10c5, 1055840, 0x1f08, 0x1f0f, 1048568,
    0x1f18, 0x1f1d, 1048568, 0x1f28, 0x1f2f, 1048568, 0x1f38, 0x1f3f, 1048568,
    0x1f48, 0x1f4d, 1048568, 0x1f68, 0x1f6f, 1048568, 0x1f88, 0x1f8f, 1048568,
    0x1f98, 0x1f9f, 1048568, 0x1fa8, 0x1faf, 1048568, 0x1fb8, 0x1fb9, 1048568,
    0x1fba, 0x1fbb, 1048502, 0x1fc8, 0x1fcb, 1048490, 0x1fd8, 0x1fd9, 1048568,
    0x1fda, 0x1fdb, 1048476, 0x1fe8, 0x1fe9, 1048568, 0x1fea, 0x1feb, 1048464,
    0x1ff8, 0x1ff9, 1048448, 0x1ffa, 0x1ffb, 1048450, 0x2160, 0x216f, 1048592,
    0x24b6, 0x24cf, 1048602, 0x2c00, 0x2c2e, 1048624, 0x2c7e, 0x2c7f, 1037761,
    0xff21, 0xff3a, 1048608, 0x10400, 0x10427, 1048616,
};

static Rune __tolowerp[] = {
    0x0100, 0x012e, 1048577, 0x0132, 0x0136, 1048577, 0x0139, 0x0147, 1048577,
    0x014a, 0x0176, 1048577, 0x017b, 0x017d, 1048577, 0x01a2, 0x01a4, 1048577,
    0x01b3, 0x01b5, 1048577, 0x01cd, 0x01db, 1048577, 0x01de, 0x01ee, 1048577,
    0x01f8, 0x021e, 1048577, 0x0222, 0x0232, 1048577, 0x0248, 0x024e, 1048577,
    0x0370, 0x0372, 1048577, 0x03d8, 0x03ee, 1048577, 0x0460, 0x0480, 1048577,
    0x048a, 0x04be, 1048577, 0x04c3, 0x04cd, 1048577, 0x04d0, 0x0526, 1048577,
    0x1e00, 0x1e94, 1048577, 0x1ea0, 0x1efe, 1048577, 0x1f59, 0x1f5f, 1048568,
    0x2c67, 0x2c6b, 1048577, 0x2c80, 0x2ce2, 1048577, 0x2ceb, 0x2ced, 1048577,
    0xa640, 0xa66c, 1048577, 0xa680, 0xa696, 1048577, 0xa722, 0xa72e, 1048577,
    0xa732, 0xa76e, 1048577, 0xa779, 0xa77b, 1048577, 0xa780, 0xa786, 1048577,
    0xa790, 0xa792, 1048577, 0xa7a0, 0xa7a8, 1048577,
};

static Rune __tolowers[] = {
    0x0130, 1048377, 0x0178, 1048455, 0x0179, 1048577, 0x0181, 1048786,
    0x0182, 1048577, 0x0184, 1048577, 0x0186, 1048782, 0x0187, 1048577,
    0x018b, 1048577, 0x018e, 1048655, 0x018f, 1048778, 0x0190, 1048779,
    0x0191, 1048577, 0x0193, 1048781, 0x0194, 1048783, 0x0196, 1048787,
    0x0197, 1048785, 0x0198, 1048577, 0x019c, 1048787, 0x019d, 1048789,
    0x019f, 1048790, 0x01a0, 1048577, 0x01a6, 1048794, 0x01a7, 1048577,
    0x01a9, 1048794, 0x01ac, 1048577, 0x01ae, 1048794, 0x01af, 1048577,
    0x01b7, 1048795, 0x01b8, 1048577, 0x01bc, 1048577, 0x01c4, 1048578,
    0x01c5, 1048577, 0x01c7, 1048578, 0x01c8, 1048577, 0x01ca, 1048578,
    0x01cb, 1048577, 0x01f1, 1048578, 0x01f2, 1048577, 0x01f4, 1048577,
    0x01f6, 1048479, 0x01f7, 1048520, 0x0220, 1048446, 0x023a, 1059371,
    0x023b, 1048577, 0x023d, 1048413, 0x023e, 1059368, 0x0241, 1048577,
    0x0243, 1048381, 0x0244, 1048645, 0x0245, 1048647, 0x0246, 1048577,
    0x0376, 1048577, 0x0386, 1048614, 0x038c, 1048640, 0x03cf, 1048584,
    0x03f4, 1048516, 0x03f7, 1048577, 0x03f9, 1048569, 0x03fa, 1048577,
    0x04c0, 1048591, 0x04c1, 1048577, 0x10c7, 1055840, 0x10cd, 1055840,
    0x1e9e, 1040961, 0x1fbc, 1048567, 0x1fcc, 1048567, 0x1fec, 1048569,
    0x1ffc, 1048567, 0x2126, 1041059, 0x212a, 1040193, 0x212b, 1040314,
    0x2132, 1048604, 0x2183, 1048577, 0x2c60, 1048577, 0x2c62, 1037833,
    0x2c63, 1044762, 0x2c64, 1037849, 0x2c6d, 1037796, 0x2c6e, 1037827,
    0x2c6f, 1037793, 0x2c70, 1037794, 0x2c72, 1048577, 0x2c75, 1048577,
    0x2cf2, 1048577, 0xa77d, 1013244, 0xa77e, 1048577, 0xa78b, 1048577,
    0xa78d, 1006296, 0xa7aa, 1006268,
};

Rune utf_tolowerrune(Rune c) {
  Rune *p;

  p = rbsearch(c, __tolowerr, nelem(__tolowerr) / 3, 3);
  if (p && c >= p[0] && c <= p[1]) return c + p[2] - 1048576;
  p = rbsearch(c, __tolowerp, nelem(__tolowerp) / 3, 3);
  if (p && c >= p[0] && c <= p[1] && !((c - p[0]) & 1))
    return c + p[2] - 1048576;
  p = rbsearch(c, __tolowers, nelem(__tolowers) / 2, 2);
  if (p && c == p[0]) return c + p[1] - 1048576;
  return c;
}
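
A worked example, not part of the change: the last column of each mapping table stores the case-mapping delta biased by 1048576 (2^20), so `c + p[2] - 1048576` recovers the lowercase codepoint; in `__tolowerp` the `!((c - p[0]) & 1)` test restricts the mapping to every other codepoint of the range:

/* {0x0041, 0x005a, 1048608}: delta 1048608 - 1048576 = +32,
 * so 'A' (0x41) -> 0x41 + 32 = 0x61 = 'a'.
 * {0x0100, 0x012e, 1048577} in __tolowerp: delta +1 on even offsets only,
 * so U+0100 maps to U+0101 while U+0101 is already lowercase. */
assert(utf_tolowerrune(0x0041) == 0x0061);
assert(utf_tolowerrune(0x0100) == 0x0101);
assert(utf_tolowerrune(0x0101) == 0x0101); /* no mapping: unchanged */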

#endif

mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h (new file, 98 lines)

@@ -0,0 +1,98 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Fork of several UTF utils originally written by Rob Pike and Ken Thompson.
#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_UTF_UTF_H_
#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_UTF_UTF_H_ 1

#include <stdint.h>

// Code-point values in Unicode 4.0 are 21 bits wide.
typedef signed int Rune;

#define uchar _utfuchar

typedef unsigned char uchar;

#define nelem(x) (sizeof(x) / sizeof((x)[0]))

enum {
  UTFmax = 4,          // maximum bytes per rune
  Runeerror = 0xFFFD,  // decoding error in UTF
  Runemax = 0x10FFFF,  // maximum rune value
};

#ifdef __cplusplus
extern "C" {
#endif

/*
 * rune routines
 */

/*
 * These routines were written by Rob Pike and Ken Thompson
 * and first appeared in Plan 9.
 * SEE ALSO
 * utf (7)
 * tcs (1)
 */

// utf_runetochar copies (encodes) one rune, pointed to by r, to at most
// UTFmax bytes starting at s and returns the number of bytes generated.

int utf_runetochar(char* s, const Rune* r);

// utf_charntorune copies (decodes) at most UTFmax bytes starting at `str` to
// one rune, pointed to by `rune`, accesses at most `length` bytes of `str`,
// and returns the number of bytes consumed.
// If the UTF sequence is incomplete within `length` bytes,
// utf_charntorune will set *rune to Runeerror and return 0. If it is complete
// but not in UTF format, it will set *rune to Runeerror and return 1.
//
// Added 2004-09-24 by Wei-Hwa Huang

int utf_charntorune(Rune* rune, const char* str, int length);

// Unicode defines some characters as letters and specifies three cases:
// upper, lower, and title. Mappings among the cases are also defined,
// although they are not exhaustive: some upper case letters have no lower
// case mapping, and so on. Unicode also defines several character
// properties, a subset of which are checked by these routines. These
// routines are based on Unicode version 3.0.0.
//
// NOTE: The routines are implemented in C, so utf_isalpharune returns 0 for
// false and 1 for true.
//
// utf_tolowerrune is the Unicode case mapping. It returns the character
// unchanged if it has no defined mapping.

Rune utf_tolowerrune(Rune r);

// utf_isalpharune tests for Unicode letters; this includes ideographs in
// addition to alphabetic characters.

int utf_isalpharune(Rune r);

// (The comments in this file were copied from the manpage files rune.3,
// isalpharune.3, and runestrcat.3. Some formatting changes were also made
// to conform to Google style. /JRM 11/11/05)

#ifdef __cplusplus
}
#endif

#endif  // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_UTF_UTF_H_
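
A short decode-loop sketch against this API (illustrative only; `text` and `length` are hypothetical caller inputs, not part of the change):

#include <stdio.h>

/* Walks a possibly truncated UTF-8 buffer, lower-casing letters as it goes. */
void print_lowered_runes(const char* text, int length) {
  Rune r;
  int i = 0;
  while (i < length) {
    int consumed = utf_charntorune(&r, text + i, length - i);
    if (consumed == 0) break;  /* incomplete sequence at end of buffer */
    if (utf_isalpharune(r)) r = utf_tolowerrune(r);
    printf("U+%04X\n", r);  /* invalid bytes surface as Runeerror (0xFFFD) */
    i += consumed;
  }
}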

@@ -80,6 +80,7 @@ cc_library(
        "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto",
        "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
        "//mediapipe/tasks/metadata:image_segmenter_metadata_schema_cc",
        "//mediapipe/tasks/metadata:metadata_schema_cc",
        "//mediapipe/util:label_map_cc_proto",
        "//mediapipe/util:label_map_util",

@@ -15,6 +15,7 @@ limitations under the License.

#include <algorithm>
#include <cstdint>
#include <limits>
#include <memory>
#include <ostream>
#include <string>

@@ -79,6 +80,133 @@ void Sigmoid(absl::Span<const float> values,
                 [](float value) { return 1. / (1 + std::exp(-value)); });
}

std::vector<Image> ProcessForCategoryMaskCpu(const Shape& input_shape,
                                             const Shape& output_shape,
                                             const SegmenterOptions& options,
                                             const float* tensors_buffer) {
  cv::Mat resized_tensors_mat;
  cv::Mat tensors_mat_view(
      input_shape.height, input_shape.width, CV_32FC(input_shape.channels),
      reinterpret_cast<void*>(const_cast<float*>(tensors_buffer)));
  if (output_shape.height == input_shape.height &&
      output_shape.width == input_shape.width) {
    resized_tensors_mat = tensors_mat_view;
  } else {
    // Resize input tensors to output size.
    // TODO(b/273633027) Use an efficient way to find values for category mask
    // instead of resizing the whole tensor.
    cv::resize(tensors_mat_view, resized_tensors_mat,
               {output_shape.width, output_shape.height}, 0, 0,
               cv::INTER_LINEAR);
  }

  // Category mask Image.
  ImageFrameSharedPtr image_frame_ptr = std::make_shared<ImageFrame>(
      ImageFormat::GRAY8, output_shape.width, output_shape.height, 1);
  Image category_mask(image_frame_ptr);

  // Fill in the maximum category in the category mask image.
  cv::Mat category_mask_mat_view =
      mediapipe::formats::MatView(image_frame_ptr.get());
  const int input_channels = input_shape.channels;
  category_mask_mat_view.forEach<uint8_t>(
      [&resized_tensors_mat, &input_channels, &options](uint8_t& pixel,
                                                        const int position[]) {
        float* tensors_buffer =
            resized_tensors_mat.ptr<float>(position[0], position[1]);
        absl::Span<float> confidence_scores(tensors_buffer, input_channels);
        // Only apply the activation function if it is SIGMOID. If NONE, no
        // activation is needed. If SOFTMAX, the model is required to have
        // input_channels > 1, and with input_channels > 1 the maximum value
        // can be found without applying the activation.
        if (options.activation() == SegmenterOptions::SIGMOID) {
          Sigmoid(confidence_scores, confidence_scores);
        }
        if (input_channels == 1) {
          // If the input tensor is a single mask, it is assumed to be a binary
          // foreground segmentation mask. For such a mask, we make foreground
          // category 1 and background category 0.
          pixel = static_cast<uint8_t>(*tensors_buffer > 0.5f);
        } else {
          const int maximum_category_idx =
              std::max_element(confidence_scores.begin(),
                               confidence_scores.end()) -
              confidence_scores.begin();
          pixel = maximum_category_idx;
        }
      });
  return {category_mask};
}
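
For reference, the per-pixel rule inside the `forEach` lambda above reduces to the following scalar sketch (`PixelCategory` is a hypothetical name, not part of the calculator):

uint8_t PixelCategory(const float* scores, int num_channels) {
  // Single channel: binary foreground mask thresholded at 0.5.
  if (num_channels == 1) return scores[0] > 0.5f ? 1 : 0;
  // Multiple channels: index of the highest-scoring category.
  int best = 0;
  for (int c = 1; c < num_channels; ++c) {
    if (scores[c] > scores[best]) best = c;
  }
  return static_cast<uint8_t>(best);
}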

std::vector<Image> ProcessForConfidenceMaskCpu(const Shape& input_shape,
                                               const Shape& output_shape,
                                               const SegmenterOptions& options,
                                               const float* tensors_buffer) {
  std::function<void(absl::Span<const float> values,
                     absl::Span<float> activated_values)>
      activation_fn;
  switch (options.activation()) {
    case SegmenterOptions::SIGMOID:
      activation_fn = &Sigmoid;
      break;
    case SegmenterOptions::SOFTMAX:
      activation_fn = &StableSoftmax;
      break;
    case SegmenterOptions::NONE:
      // Just copying for NONE activation.
      activation_fn = [](absl::Span<const float> values,
                         absl::Span<float> activated_values) {
        std::copy(values.begin(), values.end(), activated_values.begin());
      };
      break;
  }

  // TODO Use libyuv for resizing instead.
  std::vector<Image> confidence_masks;
  std::vector<cv::Mat> confidence_mask_mats;
  confidence_masks.reserve(input_shape.channels);
  confidence_mask_mats.reserve(input_shape.channels);
  for (int i = 0; i < input_shape.channels; ++i) {
    confidence_masks.push_back(Image(std::make_shared<ImageFrame>(
        ImageFormat::VEC32F1, input_shape.width, input_shape.height, 1)));
    confidence_mask_mats.push_back(mediapipe::formats::MatView(
        confidence_masks.back().GetImageFrameSharedPtr().get()));
  }

  // Applies activation function.
  const int tensor_size = input_shape.height * input_shape.width;
  std::vector<float> activated_values(input_shape.channels);
  absl::Span<float> activated_values_span(activated_values);
  for (int i = 0; i < tensor_size; ++i) {
    activation_fn(absl::MakeConstSpan(&tensors_buffer[i * input_shape.channels],
                                      input_shape.channels),
                  activated_values_span);
    for (int j = 0; j < input_shape.channels; ++j) {
      confidence_mask_mats[j].at<float>(
          i / input_shape.width, i % input_shape.width) = activated_values[j];
    }
  }
  if (output_shape.height == input_shape.height &&
      output_shape.width == input_shape.width) {
    return confidence_masks;
  }
  std::vector<Image> resized_confidence_masks;
  resized_confidence_masks.reserve(confidence_mask_mats.size());
  // Resizes segmented masks to required output size.
  for (int i = 0; i < confidence_mask_mats.size(); i++) {
    // Pre-allocates ImageFrame memory to avoid copying from cv::Mat
    // afterward.
    ImageFrameSharedPtr image_frame_ptr = std::make_shared<ImageFrame>(
        ImageFormat::VEC32F1, output_shape.width, output_shape.height, 1);
    cv::Mat resized_mask_mat_view =
        mediapipe::formats::MatView(image_frame_ptr.get());
    cv::resize(confidence_mask_mats[i], resized_mask_mat_view,
               resized_mask_mat_view.size(), 0, 0, cv::INTER_LINEAR);
    resized_confidence_masks.push_back(Image(image_frame_ptr));
  }
  return resized_confidence_masks;
}
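
`StableSoftmax` is declared elsewhere in this file; a typical max-subtracted ("stable") softmax with the same span signature would look like this sketch (an assumption, not necessarily the exact shipped implementation):

void StableSoftmaxSketch(absl::Span<const float> values,
                         absl::Span<float> activated_values) {
  // Subtracting the maximum keeps std::exp from overflowing on large logits
  // without changing the resulting probability distribution.
  const float max_value = *std::max_element(values.begin(), values.end());
  float sum = 0.0f;
  for (size_t i = 0; i < values.size(); ++i) {
    activated_values[i] = std::exp(values[i] - max_value);
    sum += activated_values[i];
  }
  for (size_t i = 0; i < values.size(); ++i) {
    activated_values[i] /= sum;
  }
}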

}  // namespace

// Converts Tensors from a vector of Tensor to Segmentation.

@@ -222,82 +350,17 @@ absl::Status TensorsToSegmentationCalculator::Process(
std::vector<Image> TensorsToSegmentationCalculator::GetSegmentationResultCpu(
    const Shape& input_shape, const Shape& output_shape,
    const float* tensors_buffer) {
  std::function<void(absl::Span<const float> values,
                     absl::Span<float> activated_values)>
      activation_fn;
  switch (options_.segmenter_options().activation()) {
    case SegmenterOptions::SIGMOID:
      activation_fn = &Sigmoid;
      break;
    case SegmenterOptions::SOFTMAX:
      activation_fn = &StableSoftmax;
      break;
    case SegmenterOptions::NONE:
      // Just copying for NONE activation.
      activation_fn = [](absl::Span<const float> values,
                         absl::Span<float> activated_values) {
        std::copy(values.begin(), values.end(), activated_values.begin());
      };
      break;
  }

  const bool is_category_mask = options_.segmenter_options().output_type() ==
                                SegmenterOptions::CATEGORY_MASK;
  const int cv_mat_type = is_category_mask ? CV_8UC1 : CV_32FC1;
  const int output_masks_num = output_shape.channels;

  // TODO Use libyuv for resizing instead.
  std::vector<cv::Mat> segmented_mask_mats;
  segmented_mask_mats.reserve(output_masks_num);
  for (int i = 0; i < output_masks_num; ++i) {
    segmented_mask_mats.push_back(
        cv::Mat(input_shape.height, input_shape.width, cv_mat_type));
  }

  // Applies activation function.
  const int tensor_size = input_shape.height * input_shape.width;
  if (is_category_mask) {
    for (int i = 0; i < tensor_size; ++i) {
      absl::Span<const float> confidence_scores(
          &tensors_buffer[i * input_shape.channels], input_shape.channels);
      const int maximum_category_idx =
          std::max_element(confidence_scores.begin(), confidence_scores.end()) -
          confidence_scores.begin();
      segmented_mask_mats[0].at<uint8_t>(
          i / input_shape.width, i % input_shape.width) = maximum_category_idx;
    }
  if (options_.segmenter_options().output_type() ==
      SegmenterOptions::CATEGORY_MASK) {
    return ProcessForCategoryMaskCpu(input_shape, output_shape,
                                     options_.segmenter_options(),
                                     tensors_buffer);
  } else {
    std::vector<float> activated_values(input_shape.channels);
    absl::Span<float> activated_values_span(activated_values);
    for (int i = 0; i < tensor_size; ++i) {
      activation_fn(
          absl::MakeConstSpan(&tensors_buffer[i * input_shape.channels],
                              input_shape.channels),
          activated_values_span);
      for (int j = 0; j < input_shape.channels; ++j) {
        segmented_mask_mats[j].at<float>(
            i / input_shape.width, i % input_shape.width) = activated_values[j];
    return ProcessForConfidenceMaskCpu(input_shape, output_shape,
                                       options_.segmenter_options(),
                                       tensors_buffer);
      }
    }
  }

  std::vector<Image> segmented_masks;
  segmented_masks.reserve(output_masks_num);
  // Resizes segmented masks to required output size.
  for (int i = 0; i < segmented_mask_mats.size(); i++) {
    // Pre-allocates ImageFrame memory to avoid copying from cv::Mat afterward.
    ImageFrameSharedPtr image_frame_ptr = std::make_shared<ImageFrame>(
        is_category_mask ? ImageFormat::GRAY8 : ImageFormat::VEC32F1,
        output_shape.width, output_shape.height, 1);
    cv::Mat resized_mask_mat_view =
        mediapipe::formats::MatView(image_frame_ptr.get());
    cv::resize(segmented_mask_mats[i], resized_mask_mat_view,
               resized_mask_mat_view.size(), 0, 0,
               cv_mat_type == CV_8UC1 ? cv::INTER_NEAREST : cv::INTER_LINEAR);
    segmented_masks.push_back(Image(image_frame_ptr));
  }
  return segmented_masks;
}

MEDIAPIPE_REGISTER_NODE(::mediapipe::tasks::TensorsToSegmentationCalculator);

@@ -101,20 +101,6 @@ ConvertImageSegmenterOptionsToProto(ImageSegmenterOptions* options) {
          SegmenterOptions::CONFIDENCE_MASK);
      break;
  }
  switch (options->activation) {
    case ImageSegmenterOptions::Activation::NONE:
      options_proto->mutable_segmenter_options()->set_activation(
          SegmenterOptions::NONE);
      break;
    case ImageSegmenterOptions::Activation::SIGMOID:
      options_proto->mutable_segmenter_options()->set_activation(
          SegmenterOptions::SIGMOID);
      break;
    case ImageSegmenterOptions::Activation::SOFTMAX:
      options_proto->mutable_segmenter_options()->set_activation(
          SegmenterOptions::SOFTMAX);
      break;
  }
  return options_proto;
}

@@ -64,15 +64,6 @@ struct ImageSegmenterOptions {

  OutputType output_type = OutputType::CATEGORY_MASK;

  // The activation function used on the raw segmentation model output.
  enum Activation {
    NONE = 0,     // No activation function is used.
    SIGMOID = 1,  // Assumes 1-channel input tensor.
    SOFTMAX = 2,  // Assumes multi-channel input tensor.
  };

  Activation activation = Activation::NONE;

  // The user-defined result callback for processing live stream data.
  // The result callback should only be specified when the running mode is set
  // to RunningMode::LIVE_STREAM.

@@ -40,6 +40,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
#include "mediapipe/tasks/metadata/image_segmenter_metadata_schema_generated.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
#include "mediapipe/util/label_map.pb.h"
#include "mediapipe/util/label_map_util.h"

@@ -74,6 +75,7 @@ constexpr char kImageGpuTag[] = "IMAGE_GPU";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
constexpr char kSegmentationMetadataName[] = "SEGMENTER_METADATA";

// Struct holding the different output streams produced by the image segmenter
// subgraph.

@@ -130,7 +132,49 @@ absl::Status ConfigureTensorsToSegmentationCalculator(
    const ImageSegmenterGraphOptions& segmenter_option,
    const core::ModelResources& model_resources,
    TensorsToSegmentationCalculatorOptions* options) {
  *options->mutable_segmenter_options() = segmenter_option.segmenter_options();
  // Set the default activation function to NONE.
  options->mutable_segmenter_options()->set_output_type(
      segmenter_option.segmenter_options().output_type());
  options->mutable_segmenter_options()->set_activation(SegmenterOptions::NONE);
  // Find the custom metadata of ImageSegmenterOptions type in model metadata.
  const auto* metadata_extractor = model_resources.GetMetadataExtractor();
  bool found_activation_in_metadata = false;
  if (metadata_extractor->GetCustomMetadataList() != nullptr &&
      metadata_extractor->GetCustomMetadataList()->size() > 0) {
    for (const auto& custom_metadata :
         *metadata_extractor->GetCustomMetadataList()) {
      if (custom_metadata->name()->str() == kSegmentationMetadataName) {
        found_activation_in_metadata = true;
        auto activation_fb =
            GetImageSegmenterOptions(custom_metadata->data()->data())
                ->activation();
        switch (activation_fb) {
          case Activation_NONE:
            options->mutable_segmenter_options()->set_activation(
                SegmenterOptions::NONE);
            break;
          case Activation_SIGMOID:
            options->mutable_segmenter_options()->set_activation(
                SegmenterOptions::SIGMOID);
            break;
          case Activation_SOFTMAX:
            options->mutable_segmenter_options()->set_activation(
                SegmenterOptions::SOFTMAX);
            break;
          default:
            return CreateStatusWithPayload(
                absl::StatusCode::kInvalidArgument,
                "Invalid activation type found in CustomMetadata of "
                "ImageSegmenterOptions type.");
        }
      }
    }
  }
  if (!found_activation_in_metadata) {
    LOG(WARNING)
        << "No activation type is found in model metadata. Use NONE for "
           "ImageSegmenterGraph.";
  }
  const tflite::Model& model = *model_resources.GetTfLiteModel();
  if (model.subgraphs()->size() != 1) {
    return CreateStatusWithPayload(

@@ -146,8 +190,6 @@ absl::Status ConfigureTensorsToSegmentationCalculator(
        MediaPipeTasksStatus::kInvalidArgumentError);
  }

  const ModelMetadataExtractor* metadata_extractor =
      model_resources.GetMetadataExtractor();
  ASSIGN_OR_RETURN(
      *options->mutable_label_items(),
      GetLabelItemsIfAny(*metadata_extractor,

@@ -401,7 +443,7 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
    } else {
      ASSIGN_OR_RETURN(const tflite::Tensor* output_tensor,
                       GetOutputTensor(model_resources));
      const int segmentation_streams_num = *output_tensor->shape()->rbegin();
      int segmentation_streams_num = *output_tensor->shape()->rbegin();
      for (int i = 0; i < segmentation_streams_num; ++i) {
        segmented_masks.push_back(Source<Image>(
            tensor_to_images[Output<Image>::Multiple(kSegmentationTag)][i]));

@@ -62,6 +62,11 @@ constexpr char kSelfie128x128WithMetadata[] = "selfie_segm_128_128_3.tflite";

constexpr char kSelfie144x256WithMetadata[] = "selfie_segm_144_256_3.tflite";

constexpr char kSelfieSegmentation[] = "selfie_segmentation.tflite";

constexpr char kSelfieSegmentationLandscape[] =
    "selfie_segmentation_landscape.tflite";

constexpr char kHairSegmentationWithMetadata[] = "hair_segmentation.tflite";

constexpr float kGoldenMaskSimilarity = 0.98;

@@ -90,13 +95,8 @@ cv::Mat PostProcessResultMask(const cv::Mat& mask) {
}

Image GetSRGBImage(const std::string& image_path) {
  // TODO: fix test so RGB really is used and not BGR/BGRA.
  // The golden masks in
  // mediapipe/app/aimatter/segmentation/segmenter_test_common.cc were
  // generated from a BGR image. To align with the unittest of the aimatter
  // segmenter, this test also reads the image as BGR (OpenCV reads images as
  // BGR). Once the correctness of the mediapipe tasks segmenter is verified,
  // regenerate the golden masks from an RGB image.
  cv::Mat image_mat = cv::imread(image_path);
  cv::cvtColor(image_mat, image_mat, cv::COLOR_BGR2RGB);
  mediapipe::ImageFrame image_frame(
      mediapipe::ImageFormat::SRGB, image_mat.cols, image_mat.rows,
      image_mat.step, image_mat.data, [image_mat](uint8_t[]) {});

@@ -304,7 +304,6 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
  options->activation = ImageSegmenterOptions::Activation::SOFTMAX;

  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                          ImageSegmenter::Create(std::move(options)));

@@ -333,7 +332,6 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) {
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
  options->activation = ImageSegmenterOptions::Activation::SOFTMAX;

  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                          ImageSegmenter::Create(std::move(options)));

@@ -364,7 +362,6 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
  options->activation = ImageSegmenterOptions::Activation::SOFTMAX;

  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                          ImageSegmenter::Create(std::move(options)));

@@ -388,7 +385,6 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) {
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kSelfie128x128WithMetadata);
  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
  options->activation = ImageSegmenterOptions::Activation::SOFTMAX;

  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                          ImageSegmenter::Create(std::move(options)));

@@ -416,7 +412,6 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) {
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kSelfie144x256WithMetadata);
  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
  options->activation = ImageSegmenterOptions::Activation::NONE;
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                          ImageSegmenter::Create(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));

@@ -435,6 +430,82 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) {
      SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
}

TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) {
  Image image =
      GetSRGBImage(JoinPath("./", kTestDataDirectory, "portrait.jpg"));
  auto options = std::make_unique<ImageSegmenterOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kSelfieSegmentation);
  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;

  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                          ImageSegmenter::Create(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
  EXPECT_EQ(confidence_masks.size(), 1);
  MP_ASSERT_OK(segmenter->Close());

  cv::Mat expected_mask = cv::imread(
      JoinPath("./", kTestDataDirectory,
               "portrait_selfie_segmentation_expected_confidence_mask.jpg"),
      cv::IMREAD_GRAYSCALE);
  cv::Mat expected_mask_float;
  expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);

  cv::Mat selfie_mask = mediapipe::formats::MatView(
      confidence_masks[0].GetImageFrameSharedPtr().get());
  EXPECT_THAT(selfie_mask,
              SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
}

TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationCategoryMask) {
  Image image =
      GetSRGBImage(JoinPath("./", kTestDataDirectory, "portrait.jpg"));
  auto options = std::make_unique<ImageSegmenterOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kSelfieSegmentation);
  options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;

  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                          ImageSegmenter::Create(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(auto category_mask, segmenter->Segment(image));
  EXPECT_EQ(category_mask.size(), 1);
  MP_ASSERT_OK(segmenter->Close());

  cv::Mat selfie_mask = mediapipe::formats::MatView(
      category_mask[0].GetImageFrameSharedPtr().get());
  cv::Mat expected_mask = cv::imread(
      JoinPath("./", kTestDataDirectory,
               "portrait_selfie_segmentation_expected_category_mask.jpg"),
      cv::IMREAD_GRAYSCALE);
  EXPECT_THAT(selfie_mask,
              SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, 255));
}

TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationLandscapeCategoryMask) {
  Image image =
      GetSRGBImage(JoinPath("./", kTestDataDirectory, "portrait.jpg"));
  auto options = std::make_unique<ImageSegmenterOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kSelfieSegmentationLandscape);
  options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;

  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                          ImageSegmenter::Create(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(auto category_mask, segmenter->Segment(image));
  EXPECT_EQ(category_mask.size(), 1);
  MP_ASSERT_OK(segmenter->Close());

  cv::Mat selfie_mask = mediapipe::formats::MatView(
      category_mask[0].GetImageFrameSharedPtr().get());
  cv::Mat expected_mask = cv::imread(
      JoinPath(
          "./", kTestDataDirectory,
          "portrait_selfie_segmentation_landscape_expected_category_mask.jpg"),
      cv::IMREAD_GRAYSCALE);
  EXPECT_THAT(selfie_mask,
              SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, 255));
}

TEST_F(ImageModeTest, SucceedsHairSegmentation) {
  Image image =
      GetSRGBAImage(JoinPath("./", kTestDataDirectory, "portrait.jpg"));

@@ -442,7 +513,6 @@ TEST_F(ImageModeTest, SucceedsHairSegmentation) {
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kHairSegmentationWithMetadata);
  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
  options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                          ImageSegmenter::Create(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));

mediapipe/tasks/cc/vision/interactive_segmenter/BUILD (new file, 76 lines)

@@ -0,0 +1,76 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

# Docs for Mediapipe Tasks Interactive Segmenter
# TODO: add doc link.
cc_library(
    name = "interactive_segmenter",
    srcs = ["interactive_segmenter.cc"],
    hdrs = ["interactive_segmenter.h"],
    deps = [
        ":interactive_segmenter_graph",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/tasks/cc/components/containers:keypoint",
        "//mediapipe/tasks/cc/core:base_options",
        "//mediapipe/tasks/cc/vision/core:base_vision_task_api",
        "//mediapipe/tasks/cc/vision/core:image_processing_options",
        "//mediapipe/tasks/cc/vision/core:running_mode",
        "//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
        "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto",
        "//mediapipe/util:color_cc_proto",
        "//mediapipe/util:render_data_cc_proto",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
    ],
)

cc_library(
    name = "interactive_segmenter_graph",
    srcs = ["interactive_segmenter_graph.cc"],
    deps = [
        "@com_google_absl//absl/strings",
        "//mediapipe/calculators/image:set_alpha_calculator",
        "//mediapipe/calculators/util:annotation_overlay_calculator",
        "//mediapipe/calculators/util:flat_color_image_calculator",
        "//mediapipe/calculators/util:flat_color_image_calculator_cc_proto",
        "//mediapipe/calculators/util:from_image_calculator",
        "//mediapipe/calculators/util:to_image_calculator",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
        "//mediapipe/tasks/cc/core:model_task_graph",
        "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
        "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
        "//mediapipe/util:color_cc_proto",
        "//mediapipe/util:label_map_cc_proto",
        "//mediapipe/util:render_data_cc_proto",
    ] + select({
        "//mediapipe/gpu:disable_gpu": [],
        "//conditions:default": [
            "//mediapipe/gpu:gpu_buffer_to_image_frame_calculator",
            "//mediapipe/gpu:image_frame_to_gpu_buffer_calculator",
        ],
    }),
    alwayslink = 1,
)

mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc (new file, 163 lines)

@@ -0,0 +1,163 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h"

#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/render_data.pb.h"

namespace mediapipe {
namespace tasks {
namespace vision {
namespace interactive_segmenter {
namespace {

constexpr char kSegmentationStreamName[] = "segmented_mask_out";
constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out";
constexpr char kRoiStreamName[] = "roi_in";
constexpr char kNormRectStreamName[] = "norm_rect_in";

constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
constexpr char kImageTag[] = "IMAGE";
constexpr char kRoiTag[] = "ROI";
constexpr char kNormRectTag[] = "NORM_RECT";

constexpr char kSubgraphTypeName[] =
    "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph";

using ::mediapipe::CalculatorGraphConfig;
using ::mediapipe::Image;
using ::mediapipe::NormalizedRect;
using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions;
using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision::
    image_segmenter::proto::ImageSegmenterGraphOptions;

// Creates a MediaPipe graph config that only contains a single subgraph node of
// "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph".
CalculatorGraphConfig CreateGraphConfig(
    std::unique_ptr<ImageSegmenterGraphOptionsProto> options) {
  api2::builder::Graph graph;
  auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
  task_subgraph.GetOptions<ImageSegmenterGraphOptionsProto>().Swap(
      options.get());
  graph.In(kImageTag).SetName(kImageInStreamName);
  graph.In(kRoiTag).SetName(kRoiStreamName);
  graph.In(kNormRectTag).SetName(kNormRectStreamName);
  task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >>
      graph.Out(kGroupedSegmentationTag);
  task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
      graph.Out(kImageTag);
  graph.In(kImageTag) >> task_subgraph.In(kImageTag);
  graph.In(kRoiTag) >> task_subgraph.In(kRoiTag);
  graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag);
  return graph.GetConfig();
}

// Converts the user-facing InteractiveSegmenterOptions struct to the internal
// ImageSegmenterGraphOptions proto.
std::unique_ptr<ImageSegmenterGraphOptionsProto>
ConvertImageSegmenterOptionsToProto(InteractiveSegmenterOptions* options) {
  auto options_proto = std::make_unique<ImageSegmenterGraphOptionsProto>();
  auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
      tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
  options_proto->mutable_base_options()->Swap(base_options_proto.get());
  switch (options->output_type) {
    case InteractiveSegmenterOptions::OutputType::CATEGORY_MASK:
      options_proto->mutable_segmenter_options()->set_output_type(
          SegmenterOptions::CATEGORY_MASK);
      break;
    case InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK:
      options_proto->mutable_segmenter_options()->set_output_type(
          SegmenterOptions::CONFIDENCE_MASK);
      break;
  }
  return options_proto;
}

// Converts the user-facing RegionOfInterest struct to the RenderData proto
// that is used in the subgraph.
absl::StatusOr<RenderData> ConvertRoiToRenderData(const RegionOfInterest& roi) {
  RenderData result;
  switch (roi.format) {
    case RegionOfInterest::UNSPECIFIED:
      return absl::InvalidArgumentError(
          "RegionOfInterest format not specified");
    case RegionOfInterest::KEYPOINT:
      RET_CHECK(roi.keypoint.has_value());
      auto* annotation = result.add_render_annotations();
      annotation->mutable_color()->set_r(255);
      auto* point = annotation->mutable_point();
      point->set_normalized(true);
      point->set_x(roi.keypoint->x);
      point->set_y(roi.keypoint->y);
      return result;
  }
  return absl::UnimplementedError("Unrecognized format");
}

}  // namespace

absl::StatusOr<std::unique_ptr<InteractiveSegmenter>>
InteractiveSegmenter::Create(
    std::unique_ptr<InteractiveSegmenterOptions> options) {
  auto options_proto = ConvertImageSegmenterOptionsToProto(options.get());
  return core::VisionTaskApiFactory::Create<InteractiveSegmenter,
                                            ImageSegmenterGraphOptionsProto>(
      CreateGraphConfig(std::move(options_proto)),
      std::move(options->base_options.op_resolver), core::RunningMode::IMAGE,
      /*packets_callback=*/nullptr);
}

absl::StatusOr<std::vector<Image>> InteractiveSegmenter::Segment(
    mediapipe::Image image, const RegionOfInterest& roi,
    std::optional<core::ImageProcessingOptions> image_processing_options) {
  if (image.UsesGpu()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrCat("GPU input images are currently not supported."),
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
  ASSIGN_OR_RETURN(
      NormalizedRect norm_rect,
      ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
  ASSIGN_OR_RETURN(RenderData roi_as_render_data, ConvertRoiToRenderData(roi));
  ASSIGN_OR_RETURN(
      auto output_packets,
      ProcessImageData(
          {{kImageInStreamName, mediapipe::MakePacket<Image>(std::move(image))},
           {kRoiStreamName,
            mediapipe::MakePacket<RenderData>(std::move(roi_as_render_data))},
           {kNormRectStreamName,
            MakePacket<NormalizedRect>(std::move(norm_rect))}}));
  return output_packets[kSegmentationStreamName].Get<std::vector<Image>>();
}

}  // namespace interactive_segmenter
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe

mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h (new file, 136 lines)

@@ -0,0 +1,136 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_CC_VISION_INTERACTIVE_SEGMENTER_INTERACTIVE_SEGMENTER_H_
#define MEDIAPIPE_TASKS_CC_VISION_INTERACTIVE_SEGMENTER_INTERACTIVE_SEGMENTER_H_

#include <memory>
#include <optional>
#include <utility>
#include <vector>

#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/components/containers/keypoint.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"

namespace mediapipe {
namespace tasks {
namespace vision {
namespace interactive_segmenter {

// The options for configuring a mediapipe interactive segmenter task.
struct InteractiveSegmenterOptions {
  // Base options for configuring MediaPipe Tasks, such as specifying the model
  // file with metadata, accelerator options, op resolver, etc.
  tasks::core::BaseOptions base_options;

  // The output type of segmentation results.
  enum OutputType {
    // Gives a single output mask where each pixel represents the class which
    // the pixel in the original image was predicted to belong to.
    CATEGORY_MASK = 0,
    // Gives a list of output masks where, for each mask, each pixel represents
    // the prediction confidence, usually in the [0, 1] range.
    CONFIDENCE_MASK = 1,
  };

  OutputType output_type = OutputType::CATEGORY_MASK;
};

// The Region-Of-Interest (ROI) to interact with.
struct RegionOfInterest {
  enum Format {
    UNSPECIFIED = 0,  // Format not specified.
    KEYPOINT = 1,     // Using keypoint to represent ROI.
  };

  // Specifies the format used to specify the region-of-interest. Note that
  // using `UNSPECIFIED` is invalid and will lead to an `InvalidArgument`
  // status being returned.
  Format format = Format::UNSPECIFIED;

  // Represents the ROI in keypoint format; this must be non-nullopt if
  // `format` is `KEYPOINT`.
  std::optional<components::containers::NormalizedKeypoint> keypoint;
};

// Performs interactive segmentation on images.
//
// Users can represent user interaction through `RegionOfInterest`, which gives
// a hint to InteractiveSegmenter to perform segmentation focusing on the given
// region of interest.
//
// The API expects a TFLite model with mandatory TFLite Model Metadata.
//
// Input tensor:
//   (kTfLiteUInt8/kTfLiteFloat32)
//   - image input of size `[batch x height x width x channels]`.
//   - batch inference is not supported (`batch` is required to be 1).
//   - RGB inputs are supported (`channels` is required to be 3).
//   - if type is kTfLiteFloat32, NormalizationOptions are required to be
//     attached to the metadata for input normalization.
// Output tensors:
//   (kTfLiteUInt8/kTfLiteFloat32)
//   - list of segmented masks.
//   - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1.
//   - if `output_type` is CONFIDENCE_MASK, float32 Image list of size
//     `channels`.
//   - batch is always 1
class InteractiveSegmenter : tasks::vision::core::BaseVisionTaskApi {
 public:
  using BaseVisionTaskApi::BaseVisionTaskApi;

  // Creates an InteractiveSegmenter from the provided options. A non-default
  // OpResolver can be specified in the BaseOptions of
  // InteractiveSegmenterOptions, to support custom Ops of the segmentation
  // model.
  static absl::StatusOr<std::unique_ptr<InteractiveSegmenter>> Create(
      std::unique_ptr<InteractiveSegmenterOptions> options);

  // Performs image segmentation on the provided single image.
  //
  // The image can be of any size with format RGB.
  //
  // The `roi` parameter is used to represent the user's region of interest
  // for segmentation.
  //
  // The optional 'image_processing_options' parameter can be used to specify
  // the rotation to apply to the image before performing segmentation, by
  // setting its 'rotation_degrees' field. Note that specifying a
  // region-of-interest using the 'region_of_interest' field is NOT supported
  // and will result in an invalid argument error being returned.
  //
  // If the output_type is CATEGORY_MASK, the returned vector contains a
  // single uint8 category mask.
  // If the output_type is CONFIDENCE_MASK, the returned vector contains one
  // float32 confidence mask per category channel.
  absl::StatusOr<std::vector<mediapipe::Image>> Segment(
      mediapipe::Image image, const RegionOfInterest& roi,
      std::optional<core::ImageProcessingOptions> image_processing_options =
          std::nullopt);

  // Shuts down the InteractiveSegmenter when all work is done.
  absl::Status Close() { return runner_->Close(); }
};

}  // namespace interactive_segmenter
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe

#endif  // MEDIAPIPE_TASKS_CC_VISION_INTERACTIVE_SEGMENTER_INTERACTIVE_SEGMENTER_H_
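
A usage sketch built only from the declarations above (the model path and tap location are hypothetical, and MediaPipe's ASSIGN_OR_RETURN macro is assumed to be available):

using ::mediapipe::tasks::vision::interactive_segmenter::InteractiveSegmenter;
using ::mediapipe::tasks::vision::interactive_segmenter::
    InteractiveSegmenterOptions;
using ::mediapipe::tasks::vision::interactive_segmenter::RegionOfInterest;

absl::StatusOr<std::vector<mediapipe::Image>> SegmentAtTap(
    mediapipe::Image image) {
  auto options = std::make_unique<InteractiveSegmenterOptions>();
  options->base_options.model_asset_path = "/path/to/model.tflite";  // hypothetical
  options->output_type = InteractiveSegmenterOptions::OutputType::CATEGORY_MASK;
  ASSIGN_OR_RETURN(std::unique_ptr<InteractiveSegmenter> segmenter,
                   InteractiveSegmenter::Create(std::move(options)));
  RegionOfInterest roi;
  roi.format = RegionOfInterest::KEYPOINT;
  // Hypothetical user tap near the image center.
  roi.keypoint = ::mediapipe::tasks::components::containers::NormalizedKeypoint{
      /*x=*/0.5f, /*y=*/0.5f};
  return segmenter->Segment(std::move(image), roi);
}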

mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_graph.cc (new file, 198 lines)

@@ -0,0 +1,198 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/strings/string_view.h"
#include "mediapipe/calculators/util/flat_color_image_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/label_map.pb.h"
#include "mediapipe/util/render_data.pb.h"

namespace mediapipe {
namespace tasks {
namespace vision {
namespace interactive_segmenter {

namespace {

using image_segmenter::proto::ImageSegmenterGraphOptions;
using ::mediapipe::Image;
using ::mediapipe::NormalizedRect;
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;

constexpr char kSegmentationTag[] = "SEGMENTATION";
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
constexpr char kImageTag[] = "IMAGE";
constexpr char kImageCpuTag[] = "IMAGE_CPU";
constexpr char kImageGpuTag[] = "IMAGE_GPU";
constexpr char kAlphaTag[] = "ALPHA";
constexpr char kAlphaGpuTag[] = "ALPHA_GPU";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kRoiTag[] = "ROI";
constexpr char kVideoTag[] = "VIDEO";

// Updates the graph to return a stream that has the same dimensions as
// `image`, with `roi` rendered onto it. If `use_gpu` is true, the returned
// `Source` is in GpuBuffer format; otherwise it uses ImageFrame.
Source<> RoiToAlpha(Source<Image> image, Source<RenderData> roi, bool use_gpu,
                    Graph& graph) {
  // TODO: Replace with efficient implementation.
  const absl::string_view image_tag_with_suffix =
      use_gpu ? kImageGpuTag : kImageCpuTag;

  // Generates a blank canvas with the same size as the input image.
  auto& flat_color = graph.AddNode("FlatColorImageCalculator");
  auto& flat_color_options =
      flat_color.GetOptions<FlatColorImageCalculatorOptions>();
  // SetAlphaCalculator only takes the 1st channel.
  flat_color_options.mutable_color()->set_r(0);
  image >> flat_color.In(kImageTag)[0];
  auto blank_canvas = flat_color.Out(kImageTag)[0];

  auto& from_mp_image = graph.AddNode("FromImageCalculator");
  blank_canvas >> from_mp_image.In(kImageTag);
  auto blank_canvas_in_cpu_or_gpu = from_mp_image.Out(image_tag_with_suffix);

  auto& roi_to_alpha = graph.AddNode("AnnotationOverlayCalculator");
  blank_canvas_in_cpu_or_gpu >>
      roi_to_alpha.In(use_gpu ? kImageGpuTag : kImageTag);
  roi >> roi_to_alpha.In(0);
  auto alpha = roi_to_alpha.Out(use_gpu ? kImageGpuTag : kImageTag);

  return alpha;
}

}  // namespace

// A "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"
// performs semantic segmentation given the user's region-of-interest. Two
// kinds of outputs are provided: SEGMENTATION and GROUPED_SEGMENTATION. Users
// can retrieve the segmented mask of only a particular category/channel from
// SEGMENTATION, and users can also get all segmented masks from
// GROUPED_SEGMENTATION.
// - Accepts CPU input images and outputs segmented masks on CPU.
//
// Inputs:
//   IMAGE - Image
//     Image to perform segmentation on.
//   ROI - RenderData proto
//     Region of interest based on user interaction. Currently only the Point
//     format is supported, and the Color has to be (255, 255, 255).
//   NORM_RECT - NormalizedRect @Optional
//     Describes image rotation and region of image to perform detection on.
//     @Optional: rect covering the whole image is used if not specified.
//
// Outputs:
//   SEGMENTATION - mediapipe::Image @Multiple
//     Segmented masks for individual categories. The segmented mask of a
//     single category can be accessed through the index-based output stream.
//   GROUPED_SEGMENTATION - std::vector<mediapipe::Image>
//     The output segmented masks grouped in a vector.
//   IMAGE - mediapipe::Image
//     The image that the image segmenter runs on.
//
// Example:
// node {
//   calculator:
//       "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"
//   input_stream: "IMAGE:image"
//   input_stream: "ROI:region_of_interest"
//   output_stream: "SEGMENTATION:segmented_masks"
//   options {
//     [mediapipe.tasks.vision.image_segmenter.proto.ImageSegmenterGraphOptions.ext]
//     {
//       base_options {
//         model_asset {
//           file_name: "/path/to/model.tflite"
//         }
//       }
//       segmenter_options {
//         output_type: CONFIDENCE_MASK
//       }
//     }
//   }
// }
class InteractiveSegmenterGraph : public core::ModelTaskGraph {
 public:
  absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
      mediapipe::SubgraphContext* sc) override {
    Graph graph;
    const auto& task_options = sc->Options<ImageSegmenterGraphOptions>();
    bool use_gpu =
        components::processors::DetermineImagePreprocessingGpuBackend(
            task_options.base_options().acceleration());

    Source<Image> image = graph[Input<Image>(kImageTag)];
    Source<RenderData> roi = graph[Input<RenderData>(kRoiTag)];
    Source<NormalizedRect> norm_rect =
        graph[Input<NormalizedRect>(kNormRectTag)];
    const absl::string_view image_tag_with_suffix =
        use_gpu ? kImageGpuTag : kImageCpuTag;
    const absl::string_view alpha_tag_with_suffix =
        use_gpu ? kAlphaGpuTag : kAlphaTag;

    auto& from_mp_image = graph.AddNode("FromImageCalculator");
    image >> from_mp_image.In(kImageTag);
    auto image_in_cpu_or_gpu = from_mp_image.Out(image_tag_with_suffix);

    auto alpha_in_cpu_or_gpu = RoiToAlpha(image, roi, use_gpu, graph);

    auto& set_alpha = graph.AddNode("SetAlphaCalculator");
    image_in_cpu_or_gpu >> set_alpha.In(use_gpu ? kImageGpuTag : kImageTag);
    alpha_in_cpu_or_gpu >> set_alpha.In(alpha_tag_with_suffix);
    auto image_in_cpu_or_gpu_with_set_alpha =
        set_alpha.Out(use_gpu ? kImageGpuTag : kImageTag);

    auto& to_mp_image = graph.AddNode("ToImageCalculator");
    image_in_cpu_or_gpu_with_set_alpha >> to_mp_image.In(image_tag_with_suffix);
    auto image_with_set_alpha = to_mp_image.Out(kImageTag);

    auto& image_segmenter = graph.AddNode(
        "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph");
    image_segmenter.GetOptions<ImageSegmenterGraphOptions>() = task_options;
    image_with_set_alpha >> image_segmenter.In(kImageTag);
    norm_rect >> image_segmenter.In(kNormRectTag);

    image_segmenter.Out(kSegmentationTag) >>
|
||||
graph[Output<Image>(kSegmentationTag)];
|
||||
image_segmenter.Out(kGroupedSegmentationTag) >>
|
||||
graph[Output<std::vector<Image>>(kGroupedSegmentationTag)];
|
||||
image_segmenter.Out(kImageTag) >> graph[Output<Image>(kImageTag)];
|
||||
|
||||
return graph.GetConfig();
|
||||
}
|
||||
};
|
||||
|
||||
// REGISTER_MEDIAPIPE_GRAPH argument has to fit on one line to work properly.
|
||||
// clang-format off
|
||||
REGISTER_MEDIAPIPE_GRAPH(
|
||||
::mediapipe::tasks::vision::interactive_segmenter::InteractiveSegmenterGraph);
|
||||
// clang-format on
|
||||
|
||||
} // namespace interactive_segmenter
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
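The graph above works by rasterizing the user's ROI onto a blank canvas, storing that rasterization in the input image's alpha channel, and then running the generic image segmenter on the result. A minimal NumPy sketch of that alpha-embedding idea (conceptual only, not MediaPipe code; `image` and `roi_mask` are assumed inputs):

import numpy as np

def embed_roi_in_alpha(image: np.ndarray, roi_mask: np.ndarray) -> np.ndarray:
  """Returns an RGBA image whose alpha channel encodes the ROI.

  Args:
    image: HxWx3 uint8 RGB image.
    roi_mask: HxW binary mask, nonzero inside the region of interest.
  """
  assert image.shape[:2] == roi_mask.shape
  alpha = (roi_mask > 0).astype(np.uint8) * 255  # HxW alpha plane
  return np.dstack([image, alpha])  # HxWx4 RGBA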
@ -0,0 +1,306 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h"

#include <memory>
#include <string>

#include "absl/flags/flag.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/components/containers/keypoint.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
#include "tensorflow/lite/mutable_op_resolver.h"

namespace mediapipe {
namespace tasks {
namespace vision {
namespace interactive_segmenter {
namespace {

using ::mediapipe::Image;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::NormalizedKeypoint;
using ::mediapipe::tasks::components::containers::RectF;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr;
using ::testing::Optional;

constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kPtmModel[] = "ptm_512_hdt_ptm_woid.tflite";
constexpr char kCatsAndDogsJpg[] = "cats_and_dogs.jpg";
// Golden masks for the dogs in cats_and_dogs.jpg.
constexpr char kCatsAndDogsMaskDog1[] = "cats_and_dogs_mask_dog1.png";
constexpr char kCatsAndDogsMaskDog2[] = "cats_and_dogs_mask_dog2.png";

constexpr float kGoldenMaskSimilarity = 0.97;

// Magnification factor used when creating the golden category masks to make
// them more human-friendly. Since the interactive segmenter has only 2
// categories, the golden mask uses 0 or 255 for each pixel.
constexpr int kGoldenMaskMagnificationFactor = 255;

// Intentionally converts the output into CV_8UC1 and then back into CV_32FC1.
// The expected outputs are stored in CV_8UC1, so this conversion allows a
// fair comparison.
cv::Mat PostProcessResultMask(const cv::Mat& mask) {
  cv::Mat mask_float;
  mask.convertTo(mask_float, CV_8UC1, 255);
  mask_float.convertTo(mask_float, CV_32FC1, 1 / 255.f);
  return mask_float;
}

double CalculateSum(const cv::Mat& m) {
  double sum = 0.0;
  cv::Scalar s = cv::sum(m);
  for (int i = 0; i < m.channels(); ++i) {
    sum += s.val[i];
  }
  return sum;
}

double CalculateSoftIOU(const cv::Mat& m1, const cv::Mat& m2) {
  cv::Mat intersection;
  cv::multiply(m1, m2, intersection);
  double intersection_value = CalculateSum(intersection);
  double union_value =
      CalculateSum(m1.mul(m1)) + CalculateSum(m2.mul(m2)) - intersection_value;
  return union_value > 0.0 ? intersection_value / union_value : 0.0;
}

MATCHER_P2(SimilarToFloatMask, expected_mask, similarity_threshold, "") {
  cv::Mat actual_mask = PostProcessResultMask(arg);
  return actual_mask.rows == expected_mask.rows &&
         actual_mask.cols == expected_mask.cols &&
         CalculateSoftIOU(actual_mask, expected_mask) > similarity_threshold;
}

MATCHER_P3(SimilarToUint8Mask, expected_mask, similarity_threshold,
           magnification_factor, "") {
  if (arg.rows != expected_mask.rows || arg.cols != expected_mask.cols) {
    return false;
  }
  int consistent_pixels = 0;
  const int num_pixels = expected_mask.rows * expected_mask.cols;
  for (int i = 0; i < num_pixels; ++i) {
    consistent_pixels +=
        (arg.data[i] * magnification_factor == expected_mask.data[i]);
  }
  return static_cast<float>(consistent_pixels) / num_pixels >=
         similarity_threshold;
}

class CreateFromOptionsTest : public tflite_shims::testing::Test {};

class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver {
 public:
  DeepLabOpResolverMissingOps() {
    AddBuiltin(::tflite::BuiltinOperator_ADD,
               ::tflite::ops::builtin::Register_ADD());
  }

  DeepLabOpResolverMissingOps(const DeepLabOpResolverMissingOps& r) = delete;
};

TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
  auto options = std::make_unique<InteractiveSegmenterOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kPtmModel);
  options->base_options.op_resolver =
      absl::make_unique<DeepLabOpResolverMissingOps>();
  auto segmenter_or = InteractiveSegmenter::Create(std::move(options));
  // TODO: Make MediaPipe InferenceCalculator report the detailed
  // interpreter errors (e.g., "Encountered unresolved custom op").
  EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInternal);
  EXPECT_THAT(
      segmenter_or.status().message(),
      testing::HasSubstr("interpreter_builder(&interpreter) == kTfLiteOk"));
}

TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
  absl::StatusOr<std::unique_ptr<InteractiveSegmenter>> segmenter_or =
      InteractiveSegmenter::Create(
          std::make_unique<InteractiveSegmenterOptions>());

  EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(
      segmenter_or.status().message(),
      HasSubstr("ExternalFile must specify at least one of 'file_content', "
                "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
  EXPECT_THAT(segmenter_or.status().GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerInitializationError))));
}

struct InteractiveSegmenterTestParams {
  std::string test_name;
  RegionOfInterest::Format format;
  NormalizedKeypoint roi;
  std::string golden_mask_file;
  float similarity_threshold;
};

using SucceedSegmentationWithRoi =
    ::testing::TestWithParam<InteractiveSegmenterTestParams>;

TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
  const InteractiveSegmenterTestParams& params = GetParam();
  MP_ASSERT_OK_AND_ASSIGN(
      Image image,
      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
  RegionOfInterest interaction_roi;
  interaction_roi.format = params.format;
  interaction_roi.keypoint = params.roi;
  auto options = std::make_unique<InteractiveSegmenterOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kPtmModel);
  options->output_type = InteractiveSegmenterOptions::OutputType::CATEGORY_MASK;

  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
                          InteractiveSegmenter::Create(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(auto category_masks,
                          segmenter->Segment(image, interaction_roi));
  EXPECT_EQ(category_masks.size(), 1);

  cv::Mat actual_mask = mediapipe::formats::MatView(
      category_masks[0].GetImageFrameSharedPtr().get());

  cv::Mat expected_mask =
      cv::imread(JoinPath("./", kTestDataDirectory, params.golden_mask_file),
                 cv::IMREAD_GRAYSCALE);
  EXPECT_THAT(actual_mask,
              SimilarToUint8Mask(expected_mask, params.similarity_threshold,
                                 kGoldenMaskMagnificationFactor));
}

TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
  const auto& params = GetParam();
  MP_ASSERT_OK_AND_ASSIGN(
      Image image,
      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
  RegionOfInterest interaction_roi;
  interaction_roi.format = params.format;
  interaction_roi.keypoint = params.roi;
  auto options = std::make_unique<InteractiveSegmenterOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kPtmModel);
  options->output_type =
      InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK;

  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
                          InteractiveSegmenter::Create(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks,
                          segmenter->Segment(image, interaction_roi));
  EXPECT_EQ(confidence_masks.size(), 2);

  cv::Mat expected_mask =
      cv::imread(JoinPath("./", kTestDataDirectory, params.golden_mask_file),
                 cv::IMREAD_GRAYSCALE);
  cv::Mat expected_mask_float;
  expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);

  cv::Mat actual_mask = mediapipe::formats::MatView(
      confidence_masks[1].GetImageFrameSharedPtr().get());
  EXPECT_THAT(actual_mask, SimilarToFloatMask(expected_mask_float,
                                              params.similarity_threshold));
}

INSTANTIATE_TEST_SUITE_P(
    SucceedSegmentationWithRoiTest, SucceedSegmentationWithRoi,
    ::testing::ValuesIn<InteractiveSegmenterTestParams>(
        {{"PointToDog1", RegionOfInterest::KEYPOINT,
          NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsMaskDog1, 0.84f},
         {"PointToDog2", RegionOfInterest::KEYPOINT,
          NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsMaskDog2,
          kGoldenMaskSimilarity}}),
    [](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>&
           info) { return info.param.test_name; });

class ImageModeTest : public tflite_shims::testing::Test {};

// TODO: Fix this unit test once the image segmenter handles post-processing
// correctly for rotated images.
TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) {
  MP_ASSERT_OK_AND_ASSIGN(
      Image image,
      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
  RegionOfInterest interaction_roi;
  interaction_roi.format = RegionOfInterest::KEYPOINT;
  interaction_roi.keypoint = NormalizedKeypoint{0.66, 0.66};
  auto options = std::make_unique<InteractiveSegmenterOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kPtmModel);
  options->output_type =
      InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK;

  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
                          InteractiveSegmenter::Create(std::move(options)));
  ImageProcessingOptions image_processing_options;
  image_processing_options.rotation_degrees = -90;
  MP_ASSERT_OK_AND_ASSIGN(
      auto confidence_masks,
      segmenter->Segment(image, interaction_roi, image_processing_options));
  EXPECT_EQ(confidence_masks.size(), 2);
}

TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
  MP_ASSERT_OK_AND_ASSIGN(
      Image image,
      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
  RegionOfInterest interaction_roi;
  interaction_roi.format = RegionOfInterest::KEYPOINT;
  interaction_roi.keypoint = NormalizedKeypoint{0.66, 0.66};
  auto options = std::make_unique<InteractiveSegmenterOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kPtmModel);
  options->output_type =
      InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK;

  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<InteractiveSegmenter> segmenter,
                          InteractiveSegmenter::Create(std::move(options)));
  RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
  ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};

  auto results =
      segmenter->Segment(image, interaction_roi, image_processing_options);
  EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(results.status().message(),
              HasSubstr("This task doesn't support region-of-interest"));
  EXPECT_THAT(
      results.status().GetPayload(kMediaPipeTasksPayload),
      Optional(absl::Cord(absl::StrCat(
          MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
}

}  // namespace
}  // namespace interactive_segmenter
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe
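For reference, `SimilarToFloatMask` scores masks with a soft intersection-over-union: the intersection of two float masks is the sum of their element-wise product, and the union is the sum of each mask's squared values minus that intersection. A minimal NumPy sketch of the same metric (illustrative only, not part of this change):

import numpy as np

def soft_iou(m1: np.ndarray, m2: np.ndarray) -> float:
  """Soft IoU between two float masks with values in [0, 1]."""
  intersection = float(np.sum(m1 * m2))
  union = float(np.sum(m1 * m1)) + float(np.sum(m2 * m2)) - intersection
  return intersection / union if union > 0.0 else 0.0

# Example: a mask compared against itself scores 1.0.
a = np.array([[0.0, 1.0], [0.5, 0.0]])
assert abs(soft_iou(a, a) - 1.0) < 1e-9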
@ -641,9 +641,6 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
          SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK);
    }

-    // TODO: remove this once activation is handled at the metadata and graph level.
-    segmenterOptionsBuilder.setActivation(
-        SegmenterOptionsProto.SegmenterOptions.Activation.SOFTMAX);
    taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder);
    return CalculatorOptions.newBuilder()
        .setExtension(
@ -73,6 +73,15 @@ py_library(
    ],
)

py_library(
    name = "keypoint",
    srcs = ["keypoint.py"],
    deps = [
        "//mediapipe/framework/formats:location_data_py_pb2",
        "//mediapipe/tasks/python/core:optional_dependencies",
    ],
)

py_library(
    name = "matrix_data",
    srcs = ["matrix_data.py"],
@ -88,6 +97,7 @@ py_library(
    deps = [
        ":bounding_box",
        ":category",
        ":keypoint",
        "//mediapipe/framework/formats:detection_py_pb2",
        "//mediapipe/framework/formats:location_data_py_pb2",
        "//mediapipe/tasks/python/core:optional_dependencies",
@ -14,12 +14,13 @@
"""Detections data class."""
|
||||
|
||||
import dataclasses
|
||||
from typing import Any, List
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from mediapipe.framework.formats import detection_pb2
|
||||
from mediapipe.framework.formats import location_data_pb2
|
||||
from mediapipe.tasks.python.components.containers import bounding_box as bounding_box_module
|
||||
from mediapipe.tasks.python.components.containers import category as category_module
|
||||
from mediapipe.tasks.python.components.containers import keypoint as keypoint_module
|
||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||
|
||||
_DetectionListProto = detection_pb2.DetectionList
|
||||
|
@ -34,10 +35,12 @@ class Detection:
|
|||
Attributes:
|
||||
bounding_box: A BoundingBox object.
|
||||
categories: A list of Category objects.
|
||||
keypoints: A list of NormalizedKeypoint objects.
|
||||
"""
|
||||
|
||||
bounding_box: bounding_box_module.BoundingBox
|
||||
categories: List[category_module.Category]
|
||||
keypoints: Optional[List[keypoint_module.NormalizedKeypoint]] = None
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def to_pb2(self) -> _DetectionProto:
|
||||
|
@ -46,6 +49,8 @@ class Detection:
|
|||
label_ids = []
|
||||
scores = []
|
||||
display_names = []
|
||||
relative_keypoints = []
|
||||
|
||||
for category in self.categories:
|
||||
scores.append(category.score)
|
||||
if category.index:
|
||||
|
@ -54,6 +59,20 @@ class Detection:
|
|||
labels.append(category.category_name)
|
||||
if category.display_name:
|
||||
display_names.append(category.display_name)
|
||||
|
||||
if self.keypoints:
|
||||
for keypoint in self.keypoints:
|
||||
relative_keypoint_proto = _LocationDataProto.RelativeKeypoint()
|
||||
if keypoint.x:
|
||||
relative_keypoint_proto.x = keypoint.x
|
||||
if keypoint.y:
|
||||
relative_keypoint_proto.y = keypoint.y
|
||||
if keypoint.label:
|
||||
relative_keypoint_proto.keypoint_label = keypoint.label
|
||||
if keypoint.score:
|
||||
relative_keypoint_proto.score = keypoint.score
|
||||
relative_keypoints.append(relative_keypoint_proto)
|
||||
|
||||
return _DetectionProto(
|
||||
label=labels,
|
||||
label_id=label_ids,
|
||||
|
@ -61,28 +80,52 @@ class Detection:
|
|||
display_name=display_names,
|
||||
location_data=_LocationDataProto(
|
||||
format=_LocationDataProto.Format.BOUNDING_BOX,
|
||||
bounding_box=self.bounding_box.to_pb2()))
|
||||
bounding_box=self.bounding_box.to_pb2(),
|
||||
relative_keypoints=relative_keypoints,
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@doc_controls.do_not_generate_docs
|
||||
def create_from_pb2(cls, pb2_obj: _DetectionProto) -> 'Detection':
|
||||
"""Creates a `Detection` object from the given protobuf object."""
|
||||
categories = []
|
||||
keypoints = []
|
||||
|
||||
for idx, score in enumerate(pb2_obj.score):
|
||||
categories.append(
|
||||
category_module.Category(
|
||||
score=score,
|
||||
index=pb2_obj.label_id[idx]
|
||||
if idx < len(pb2_obj.label_id) else None,
|
||||
if idx < len(pb2_obj.label_id)
|
||||
else None,
|
||||
category_name=pb2_obj.label[idx]
|
||||
if idx < len(pb2_obj.label) else None,
|
||||
if idx < len(pb2_obj.label)
|
||||
else None,
|
||||
display_name=pb2_obj.display_name[idx]
|
||||
if idx < len(pb2_obj.display_name) else None))
|
||||
if idx < len(pb2_obj.display_name)
|
||||
else None,
|
||||
)
|
||||
)
|
||||
|
||||
if pb2_obj.location_data.relative_keypoints:
|
||||
for idx, elem in enumerate(pb2_obj.location_data.relative_keypoints):
|
||||
keypoints.append(
|
||||
keypoint_module.NormalizedKeypoint(
|
||||
x=elem.x,
|
||||
y=elem.y,
|
||||
label=elem.keypoint_label,
|
||||
score=elem.score,
|
||||
)
|
||||
)
|
||||
|
||||
return Detection(
|
||||
bounding_box=bounding_box_module.BoundingBox.create_from_pb2(
|
||||
pb2_obj.location_data.bounding_box),
|
||||
categories=categories)
|
||||
pb2_obj.location_data.bounding_box
|
||||
),
|
||||
categories=categories,
|
||||
keypoints=keypoints,
|
||||
)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Checks if this object is equal to the given object.
77
mediapipe/tasks/python/components/containers/keypoint.py
Normal file
@ -0,0 +1,77 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keypoint data class."""

import dataclasses
from typing import Any, Optional

from mediapipe.framework.formats import location_data_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls

_RelativeKeypointProto = location_data_pb2.LocationData.RelativeKeypoint


@dataclasses.dataclass
class NormalizedKeypoint:
  """A normalized keypoint.

  A normalized keypoint represents a point in 2D space with x, y coordinates.
  x and y are normalized to [0.0, 1.0] by the image width and height,
  respectively.

  Attributes:
    x: The x coordinate of the normalized keypoint.
    y: The y coordinate of the normalized keypoint.
    label: The optional label of the keypoint.
    score: The score of the keypoint.
  """

  x: Optional[float] = None
  y: Optional[float] = None
  label: Optional[str] = None
  score: Optional[float] = None

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _RelativeKeypointProto:
    """Generates a RelativeKeypoint protobuf object."""
    return _RelativeKeypointProto(
        x=self.x, y=self.y, keypoint_label=self.label, score=self.score
    )

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(
      cls, pb2_obj: _RelativeKeypointProto
  ) -> 'NormalizedKeypoint':
    """Creates a `NormalizedKeypoint` object from the given protobuf object."""
    return NormalizedKeypoint(
        x=pb2_obj.x,
        y=pb2_obj.y,
        label=pb2_obj.keypoint_label,
        score=pb2_obj.score,
    )

  def __eq__(self, other: Any) -> bool:
    """Checks if this object is equal to the given object.

    Args:
      other: The object to be compared with.

    Returns:
      True if the objects are equal.
    """
    if not isinstance(other, NormalizedKeypoint):
      return False

    return self.to_pb2().__eq__(other.to_pb2())
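As a quick illustration of the container above, a hypothetical round trip through the protobuf representation (values are arbitrary):

from mediapipe.tasks.python.components.containers import keypoint as keypoint_module

kp = keypoint_module.NormalizedKeypoint(x=0.5, y=0.25, label='left_eye', score=0.9)
proto = kp.to_pb2()  # a location_data_pb2.LocationData.RelativeKeypoint
restored = keypoint_module.NormalizedKeypoint.create_from_pb2(proto)
assert kp == restored  # __eq__ compares the underlying protos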
@ -92,6 +92,29 @@ py_test(
    ],
)

py_test(
    name = "face_detector_test",
    srcs = ["face_detector_test.py"],
    data = [
        "//mediapipe/tasks/testdata/vision:test_images",
        "//mediapipe/tasks/testdata/vision:test_models",
        "//mediapipe/tasks/testdata/vision:test_protos",
    ],
    deps = [
        "//mediapipe/framework/formats:detection_py_pb2",
        "//mediapipe/python:_framework_bindings",
        "//mediapipe/tasks/python/components/containers:bounding_box",
        "//mediapipe/tasks/python/components/containers:category",
        "//mediapipe/tasks/python/components/containers:detections",
        "//mediapipe/tasks/python/core:base_options",
        "//mediapipe/tasks/python/test:test_utils",
        "//mediapipe/tasks/python/vision:face_detector",
        "//mediapipe/tasks/python/vision/core:image_processing_options",
        "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
        "@com_google_protobuf//:protobuf_python",
    ],
)

py_test(
    name = "hand_landmarker_test",
    srcs = ["hand_landmarker_test.py"],
523
mediapipe/tasks/python/test/vision/face_detector_test.py
Normal file
@ -0,0 +1,523 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for face detector."""

import enum
import os
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized

from google.protobuf import text_format
from mediapipe.framework.formats import detection_pb2
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.components.containers import bounding_box as bounding_box_module
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import detections as detections_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import face_detector
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module

FaceDetectorResult = detections_module.DetectionResult
_BaseOptions = base_options_module.BaseOptions
_Category = category_module.Category
_BoundingBox = bounding_box_module.BoundingBox
_Detection = detections_module.Detection
_Image = image_module.Image
_FaceDetector = face_detector.FaceDetector
_FaceDetectorOptions = face_detector.FaceDetectorOptions
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions

_SHORT_RANGE_BLAZE_FACE_MODEL = 'face_detection_short_range.tflite'
_PORTRAIT_IMAGE = 'portrait.jpg'
_PORTRAIT_EXPECTED_DETECTION = 'portrait_expected_detection.pbtxt'
_PORTRAIT_ROTATED_IMAGE = 'portrait_rotated.jpg'
_PORTRAIT_ROTATED_EXPECTED_DETECTION = (
    'portrait_rotated_expected_detection.pbtxt'
)
_CAT_IMAGE = 'cat.jpg'
_KEYPOINT_ERROR_THRESHOLD = 1e-2
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'


def _get_expected_face_detector_result(file_name: str) -> FaceDetectorResult:
  face_detection_result_file_path = test_utils.get_test_data_path(
      os.path.join(_TEST_DATA_DIR, file_name)
  )
  with open(face_detection_result_file_path, 'rb') as f:
    face_detection_proto = detection_pb2.Detection()
    text_format.Parse(f.read(), face_detection_proto)
  face_detection = detections_module.Detection.create_from_pb2(
      face_detection_proto
  )
  return FaceDetectorResult(detections=[face_detection])


class ModelFileType(enum.Enum):
  FILE_CONTENT = 1
  FILE_NAME = 2


class FaceDetectorTest(parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.test_image = _Image.create_from_file(
        test_utils.get_test_data_path(
            os.path.join(_TEST_DATA_DIR, _PORTRAIT_IMAGE)
        )
    )
    self.model_path = test_utils.get_test_data_path(
        os.path.join(_TEST_DATA_DIR, _SHORT_RANGE_BLAZE_FACE_MODEL)
    )

  def test_create_from_file_succeeds_with_valid_model_path(self):
    # Creates with default option and valid model file successfully.
    with _FaceDetector.create_from_model_path(self.model_path) as detector:
      self.assertIsInstance(detector, _FaceDetector)

  def test_create_from_options_succeeds_with_valid_model_path(self):
    # Creates with options containing model file successfully.
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _FaceDetectorOptions(base_options=base_options)
    with _FaceDetector.create_from_options(options) as detector:
      self.assertIsInstance(detector, _FaceDetector)

  def test_create_from_options_fails_with_invalid_model_path(self):
    with self.assertRaisesRegex(
        RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'
    ):
      base_options = _BaseOptions(
          model_asset_path='/path/to/invalid/model.tflite'
      )
      options = _FaceDetectorOptions(base_options=base_options)
      _FaceDetector.create_from_options(options)

  def test_create_from_options_succeeds_with_valid_model_content(self):
    # Creates with options containing model content successfully.
    with open(self.model_path, 'rb') as f:
      base_options = _BaseOptions(model_asset_buffer=f.read())
      options = _FaceDetectorOptions(base_options=base_options)
      detector = _FaceDetector.create_from_options(options)
      self.assertIsInstance(detector, _FaceDetector)

  def _expect_keypoints_correct(self, actual_keypoints, expected_keypoints):
    self.assertLen(actual_keypoints, len(expected_keypoints))
    for i in range(len(actual_keypoints)):
      self.assertAlmostEqual(
          actual_keypoints[i].x,
          expected_keypoints[i].x,
          delta=_KEYPOINT_ERROR_THRESHOLD,
      )
      self.assertAlmostEqual(
          actual_keypoints[i].y,
          expected_keypoints[i].y,
          delta=_KEYPOINT_ERROR_THRESHOLD,
      )

  def _expect_face_detector_results_correct(
      self, actual_results, expected_results
  ):
    self.assertLen(actual_results.detections, len(expected_results.detections))
    for i in range(len(actual_results.detections)):
      actual_bbox = actual_results.detections[i].bounding_box
      expected_bbox = expected_results.detections[i].bounding_box
      self.assertEqual(actual_bbox, expected_bbox)
      self.assertNotEmpty(actual_results.detections[i].keypoints)
      self._expect_keypoints_correct(
          actual_results.detections[i].keypoints,
          expected_results.detections[i].keypoints,
      )

  @parameterized.parameters(
      (ModelFileType.FILE_NAME, _PORTRAIT_EXPECTED_DETECTION),
      (ModelFileType.FILE_CONTENT, _PORTRAIT_EXPECTED_DETECTION),
  )
  def test_detect(self, model_file_type, expected_detection_result_file):
    # Creates detector.
    if model_file_type is ModelFileType.FILE_NAME:
      base_options = _BaseOptions(model_asset_path=self.model_path)
    elif model_file_type is ModelFileType.FILE_CONTENT:
      with open(self.model_path, 'rb') as f:
        model_content = f.read()
      base_options = _BaseOptions(model_asset_buffer=model_content)
    else:
      # Should never happen.
      raise ValueError('model_file_type is invalid.')

    options = _FaceDetectorOptions(base_options=base_options)
    detector = _FaceDetector.create_from_options(options)

    # Performs face detection on the input.
    detection_result = detector.detect(self.test_image)
    # Comparing results.
    expected_detection_result = _get_expected_face_detector_result(
        expected_detection_result_file
    )
    self._expect_face_detector_results_correct(
        detection_result, expected_detection_result
    )
    # Closes the detector explicitly when the detector is not used in
    # a context.
    detector.close()

  @parameterized.parameters(
      (ModelFileType.FILE_NAME, _PORTRAIT_EXPECTED_DETECTION),
      (ModelFileType.FILE_CONTENT, _PORTRAIT_EXPECTED_DETECTION),
  )
  def test_detect_in_context(
      self, model_file_type, expected_detection_result_file
  ):
    # Creates detector.
    if model_file_type is ModelFileType.FILE_NAME:
      base_options = _BaseOptions(model_asset_path=self.model_path)
    elif model_file_type is ModelFileType.FILE_CONTENT:
      with open(self.model_path, 'rb') as f:
        model_content = f.read()
      base_options = _BaseOptions(model_asset_buffer=model_content)
    else:
      # Should never happen.
      raise ValueError('model_file_type is invalid.')

    options = _FaceDetectorOptions(base_options=base_options)

    with _FaceDetector.create_from_options(options) as detector:
      # Performs face detection on the input.
      detection_result = detector.detect(self.test_image)
      # Comparing results.
      expected_detection_result = _get_expected_face_detector_result(
          expected_detection_result_file
      )
      self._expect_face_detector_results_correct(
          detection_result, expected_detection_result
      )

  def test_detect_succeeds_with_rotated_image(self):
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _FaceDetectorOptions(base_options=base_options)
    with _FaceDetector.create_from_options(options) as detector:
      # Load the test image.
      test_image = _Image.create_from_file(
          test_utils.get_test_data_path(
              os.path.join(_TEST_DATA_DIR, _PORTRAIT_ROTATED_IMAGE)
          )
      )
      # Rotated input image.
      image_processing_options = _ImageProcessingOptions(rotation_degrees=-90)
      # Performs face detection on the input.
      detection_result = detector.detect(test_image, image_processing_options)
      # Comparing results.
      expected_detection_result = _get_expected_face_detector_result(
          _PORTRAIT_ROTATED_EXPECTED_DETECTION
      )
      self._expect_face_detector_results_correct(
          detection_result, expected_detection_result
      )

  def test_empty_detection_outputs(self):
    # Load a test image with no faces.
    test_image = _Image.create_from_file(
        test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))
    )
    options = _FaceDetectorOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path)
    )
    with _FaceDetector.create_from_options(options) as detector:
      # Performs face detection on the input.
      detection_result = detector.detect(test_image)
      self.assertEmpty(detection_result.detections)

  def test_missing_result_callback(self):
    options = _FaceDetectorOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.LIVE_STREAM,
    )
    with self.assertRaisesRegex(
        ValueError, r'result callback must be provided'
    ):
      with _FaceDetector.create_from_options(options) as unused_detector:
        pass

  @parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
  def test_illegal_result_callback(self, running_mode):
    options = _FaceDetectorOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=running_mode,
        result_callback=mock.MagicMock(),
    )
    with self.assertRaisesRegex(
        ValueError, r'result callback should not be provided'
    ):
      with _FaceDetector.create_from_options(options) as unused_detector:
        pass

  def test_calling_detect_for_video_in_image_mode(self):
    options = _FaceDetectorOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.IMAGE,
    )
    with _FaceDetector.create_from_options(options) as detector:
      with self.assertRaisesRegex(
          ValueError, r'not initialized with the video mode'
      ):
        detector.detect_for_video(self.test_image, 0)

  def test_calling_detect_async_in_image_mode(self):
    options = _FaceDetectorOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.IMAGE,
    )
    with _FaceDetector.create_from_options(options) as detector:
      with self.assertRaisesRegex(
          ValueError, r'not initialized with the live stream mode'
      ):
        detector.detect_async(self.test_image, 0)

  def test_calling_detect_in_video_mode(self):
    options = _FaceDetectorOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.VIDEO,
    )
    with _FaceDetector.create_from_options(options) as detector:
      with self.assertRaisesRegex(
          ValueError, r'not initialized with the image mode'
      ):
        detector.detect(self.test_image)

  def test_calling_detect_async_in_video_mode(self):
    options = _FaceDetectorOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.VIDEO,
    )
    with _FaceDetector.create_from_options(options) as detector:
      with self.assertRaisesRegex(
          ValueError, r'not initialized with the live stream mode'
      ):
        detector.detect_async(self.test_image, 0)

  def test_detect_for_video_with_out_of_order_timestamp(self):
    options = _FaceDetectorOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.VIDEO,
    )
    with _FaceDetector.create_from_options(options) as detector:
      unused_result = detector.detect_for_video(self.test_image, 1)
      with self.assertRaisesRegex(
          ValueError, r'Input timestamp must be monotonically increasing'
      ):
        detector.detect_for_video(self.test_image, 0)

  @parameterized.parameters(
      (
          ModelFileType.FILE_NAME,
          _PORTRAIT_IMAGE,
          0,
          _get_expected_face_detector_result(_PORTRAIT_EXPECTED_DETECTION),
      ),
      (
          ModelFileType.FILE_CONTENT,
          _PORTRAIT_IMAGE,
          0,
          _get_expected_face_detector_result(_PORTRAIT_EXPECTED_DETECTION),
      ),
      (
          ModelFileType.FILE_NAME,
          _PORTRAIT_ROTATED_IMAGE,
          -90,
          _get_expected_face_detector_result(
              _PORTRAIT_ROTATED_EXPECTED_DETECTION
          ),
      ),
      (
          ModelFileType.FILE_CONTENT,
          _PORTRAIT_ROTATED_IMAGE,
          -90,
          _get_expected_face_detector_result(
              _PORTRAIT_ROTATED_EXPECTED_DETECTION
          ),
      ),
      (ModelFileType.FILE_NAME, _CAT_IMAGE, 0, FaceDetectorResult([])),
      (ModelFileType.FILE_CONTENT, _CAT_IMAGE, 0, FaceDetectorResult([])),
  )
  def test_detect_for_video(
      self,
      model_file_type,
      test_image_file_name,
      rotation_degrees,
      expected_detection_result,
  ):
    # Creates detector.
    if model_file_type is ModelFileType.FILE_NAME:
      base_options = _BaseOptions(model_asset_path=self.model_path)
    elif model_file_type is ModelFileType.FILE_CONTENT:
      with open(self.model_path, 'rb') as f:
        model_content = f.read()
      base_options = _BaseOptions(model_asset_buffer=model_content)
    else:
      # Should never happen.
      raise ValueError('model_file_type is invalid.')

    options = _FaceDetectorOptions(
        base_options=base_options, running_mode=_RUNNING_MODE.VIDEO
    )

    with _FaceDetector.create_from_options(options) as detector:
      for timestamp in range(0, 300, 30):
        # Load the test image.
        test_image = _Image.create_from_file(
            test_utils.get_test_data_path(
                os.path.join(_TEST_DATA_DIR, test_image_file_name)
            )
        )
        # Set the image processing options.
        image_processing_options = _ImageProcessingOptions(
            rotation_degrees=rotation_degrees
        )
        # Performs face detection on the input.
        detection_result = detector.detect_for_video(
            test_image, timestamp, image_processing_options
        )
        # Comparing results.
        self._expect_face_detector_results_correct(
            detection_result, expected_detection_result
        )

  def test_calling_detect_in_live_stream_mode(self):
    options = _FaceDetectorOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.LIVE_STREAM,
        result_callback=mock.MagicMock(),
    )
    with _FaceDetector.create_from_options(options) as detector:
      with self.assertRaisesRegex(
          ValueError, r'not initialized with the image mode'
      ):
        detector.detect(self.test_image)

  def test_calling_detect_for_video_in_live_stream_mode(self):
    options = _FaceDetectorOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.LIVE_STREAM,
        result_callback=mock.MagicMock(),
    )
    with _FaceDetector.create_from_options(options) as detector:
      with self.assertRaisesRegex(
          ValueError, r'not initialized with the video mode'
      ):
        detector.detect_for_video(self.test_image, 0)

  def test_detect_async_calls_with_illegal_timestamp(self):
    options = _FaceDetectorOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.LIVE_STREAM,
        result_callback=mock.MagicMock(),
    )
    with _FaceDetector.create_from_options(options) as detector:
      detector.detect_async(self.test_image, 100)
      with self.assertRaisesRegex(
          ValueError, r'Input timestamp must be monotonically increasing'
      ):
        detector.detect_async(self.test_image, 0)

  @parameterized.parameters(
      (
          ModelFileType.FILE_NAME,
          _PORTRAIT_IMAGE,
          0,
          _get_expected_face_detector_result(_PORTRAIT_EXPECTED_DETECTION),
      ),
      (
          ModelFileType.FILE_CONTENT,
          _PORTRAIT_IMAGE,
          0,
          _get_expected_face_detector_result(_PORTRAIT_EXPECTED_DETECTION),
      ),
      (
          ModelFileType.FILE_NAME,
          _PORTRAIT_ROTATED_IMAGE,
          -90,
          _get_expected_face_detector_result(
              _PORTRAIT_ROTATED_EXPECTED_DETECTION
          ),
      ),
      (
          ModelFileType.FILE_CONTENT,
          _PORTRAIT_ROTATED_IMAGE,
          -90,
          _get_expected_face_detector_result(
              _PORTRAIT_ROTATED_EXPECTED_DETECTION
          ),
      ),
      (ModelFileType.FILE_NAME, _CAT_IMAGE, 0, FaceDetectorResult([])),
      (ModelFileType.FILE_CONTENT, _CAT_IMAGE, 0, FaceDetectorResult([])),
  )
  def test_detect_async_calls(
      self,
      model_file_type,
      test_image_file_name,
      rotation_degrees,
      expected_detection_result,
  ):
    # Creates detector.
    if model_file_type is ModelFileType.FILE_NAME:
      base_options = _BaseOptions(model_asset_path=self.model_path)
    elif model_file_type is ModelFileType.FILE_CONTENT:
      with open(self.model_path, 'rb') as f:
        model_content = f.read()
      base_options = _BaseOptions(model_asset_buffer=model_content)
    else:
      # Should never happen.
      raise ValueError('model_file_type is invalid.')

    observed_timestamp_ms = -1

    def check_result(
        result: FaceDetectorResult,
        unused_output_image: _Image,
        timestamp_ms: int,
    ):
      self._expect_face_detector_results_correct(
          result, expected_detection_result
      )
      self.assertLess(observed_timestamp_ms, timestamp_ms)
      self.observed_timestamp_ms = timestamp_ms

    options = _FaceDetectorOptions(
        base_options=base_options,
        running_mode=_RUNNING_MODE.LIVE_STREAM,
        result_callback=check_result,
    )

    # Load the test image.
    test_image = _Image.create_from_file(
        test_utils.get_test_data_path(
            os.path.join(_TEST_DATA_DIR, test_image_file_name)
        )
    )

    with _FaceDetector.create_from_options(options) as detector:
      for timestamp in range(0, 300, 30):
        # Set the image processing options.
        image_processing_options = _ImageProcessingOptions(
            rotation_degrees=rotation_degrees
        )
        detector.detect_async(test_image, timestamp, image_processing_options)


if __name__ == '__main__':
  absltest.main()
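Distilled from the live-stream tests above, a minimal usage sketch of the API (illustrative only; the file paths are placeholders and `frame` stands in for per-frame MediaPipe Images from a camera):

from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.vision import face_detector
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module

def on_result(result, output_image, timestamp_ms):
  # Called asynchronously with monotonically increasing timestamps.
  print(f'{timestamp_ms} ms: {len(result.detections)} face(s)')

options = face_detector.FaceDetectorOptions(
    base_options=base_options_module.BaseOptions(
        model_asset_path='/path/to/face_detection_short_range.tflite'),
    running_mode=running_mode_module.VisionTaskRunningMode.LIVE_STREAM,
    result_callback=on_result,
)
frame = image_module.Image.create_from_file('/path/to/frame.jpg')
with face_detector.FaceDetector.create_from_options(options) as detector:
  for timestamp_ms in range(0, 300, 30):
    detector.detect_async(frame, timestamp_ms)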
@ -153,6 +153,26 @@ py_library(
    ],
)

py_library(
    name = "face_detector",
    srcs = [
        "face_detector.py",
    ],
    deps = [
        "//mediapipe/python:_framework_bindings",
        "//mediapipe/python:packet_creator",
        "//mediapipe/python:packet_getter",
        "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_py_pb2",
        "//mediapipe/tasks/python/components/containers:detections",
        "//mediapipe/tasks/python/core:base_options",
        "//mediapipe/tasks/python/core:optional_dependencies",
        "//mediapipe/tasks/python/core:task_info",
        "//mediapipe/tasks/python/vision/core:base_vision_task_api",
        "//mediapipe/tasks/python/vision/core:image_processing_options",
        "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
    ],
)

py_library(
    name = "face_landmarker",
    srcs = [
332
mediapipe/tasks/python/vision/face_detector.py
Normal file
@ -0,0 +1,332 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""MediaPipe face detector task."""
|
||||
|
||||
import dataclasses
|
||||
from typing import Callable, Mapping, Optional
|
||||
|
||||
from mediapipe.python import packet_creator
|
||||
from mediapipe.python import packet_getter
|
||||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.python._framework_bindings import packet as packet_module
|
||||
from mediapipe.tasks.cc.vision.face_detector.proto import face_detector_graph_options_pb2
|
||||
from mediapipe.tasks.python.components.containers import detections as detections_module
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.core import task_info as task_info_module
|
||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||
from mediapipe.tasks.python.vision.core import base_vision_task_api
|
||||
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
||||
|
||||
FaceDetectorResult = detections_module.DetectionResult
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_FaceDetectorGraphOptionsProto = (
|
||||
face_detector_graph_options_pb2.FaceDetectorGraphOptions
|
||||
)
|
||||
_RunningMode = running_mode_module.VisionTaskRunningMode
|
||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||
_TaskInfo = task_info_module.TaskInfo
|
||||
|
||||
_DETECTIONS_OUT_STREAM_NAME = 'detections'
|
||||
_DETECTIONS_TAG = 'DETECTIONS'
|
||||
_NORM_RECT_STREAM_NAME = 'norm_rect_in'
|
||||
_NORM_RECT_TAG = 'NORM_RECT'
|
||||
_IMAGE_IN_STREAM_NAME = 'image_in'
|
||||
_IMAGE_OUT_STREAM_NAME = 'image_out'
|
||||
_IMAGE_TAG = 'IMAGE'
|
||||
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.face_detector.FaceDetectorGraph'
|
||||
_MICRO_SECONDS_PER_MILLISECOND = 1000
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class FaceDetectorOptions:
|
||||
"""Options for the face detector task.
|
||||
|
||||
Attributes:
|
||||
base_options: Base options for the face detector task.
|
||||
running_mode: The running mode of the task. Default to the image mode. Face
|
||||
detector task has three running modes: 1) The image mode for detecting
|
||||
faces on single image inputs. 2) The video mode for detecting faces on the
|
||||
decoded frames of a video. 3) The live stream mode for detecting faces on
|
||||
a live stream of input data, such as from camera.
|
||||
min_detection_confidence: The minimum confidence score for the face
|
||||
detection to be considered successful.
|
||||
min_suppression_threshold: The minimum non-maximum-suppression threshold for
|
||||
face detection to be considered overlapped.
|
||||
result_callback: The user-defined result callback for processing live stream
|
||||
data. The result callback should only be specified when the running mode
|
||||
is set to the live stream mode.
|
||||
"""
|
||||
|
||||
base_options: _BaseOptions
|
||||
running_mode: _RunningMode = _RunningMode.IMAGE
|
||||
min_detection_confidence: Optional[float] = None
|
||||
min_suppression_threshold: Optional[float] = None
|
||||
result_callback: Optional[
|
||||
Callable[
|
||||
[detections_module.DetectionResult, image_module.Image, int], None
|
||||
]
|
||||
] = None
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def to_pb2(self) -> _FaceDetectorGraphOptionsProto:
|
||||
"""Generates an FaceDetectorOptions protobuf object."""
|
||||
base_options_proto = self.base_options.to_pb2()
|
||||
base_options_proto.use_stream_mode = (
|
||||
False if self.running_mode == _RunningMode.IMAGE else True
|
||||
)
|
||||
return _FaceDetectorGraphOptionsProto(
|
||||
base_options=base_options_proto,
|
||||
min_detection_confidence=self.min_detection_confidence,
|
||||
min_suppression_threshold=self.min_suppression_threshold,
|
||||
)
|
||||
|
||||
|
||||
class FaceDetector(base_vision_task_api.BaseVisionTaskApi):
|
||||
"""Class that performs face detection on images."""
|
||||
|
||||
@classmethod
|
||||
def create_from_model_path(cls, model_path: str) -> 'FaceDetector':
|
||||
"""Creates an `FaceDetector` object from a TensorFlow Lite model and the default `FaceDetectorOptions`.
|
||||
|
||||
Note that the created `FaceDetector` instance is in image mode, for
|
||||
detecting faces on single image inputs.
|
||||
|
||||
Args:
|
||||
model_path: Path to the model.
|
||||
|
||||
Returns:
|
||||
`FaceDetector` object that's created from the model file and the default
|
||||
`FaceDetectorOptions`.
|
||||
|
||||
Raises:
|
||||
ValueError: If failed to create `FaceDetector` object from the provided
|
||||
file such as invalid file path.
|
||||
RuntimeError: If other types of error occurred.
|
||||
"""
|
||||
base_options = _BaseOptions(model_asset_path=model_path)
|
||||
options = FaceDetectorOptions(
|
||||
base_options=base_options, running_mode=_RunningMode.IMAGE
|
||||
)
|
||||
return cls.create_from_options(options)
|
||||
|
||||
  @classmethod
  def create_from_options(cls, options: FaceDetectorOptions) -> 'FaceDetector':
    """Creates the `FaceDetector` object from face detector options.

    Args:
      options: Options for the face detector task.

    Returns:
      `FaceDetector` object that's created from `options`.

    Raises:
      ValueError: If the `FaceDetector` object cannot be created from
        `FaceDetectorOptions`, e.g. because the model is missing.
      RuntimeError: If other types of errors occur.
    """

    def packets_callback(output_packets: Mapping[str, packet_module.Packet]):
      if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
        return
      image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
      if output_packets[_DETECTIONS_OUT_STREAM_NAME].is_empty():
        empty_packet = output_packets[_DETECTIONS_OUT_STREAM_NAME]
        options.result_callback(
            FaceDetectorResult([]),
            image,
            empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
        )
        return
      detection_proto_list = packet_getter.get_proto_list(
          output_packets[_DETECTIONS_OUT_STREAM_NAME]
      )
      detection_result = detections_module.DetectionResult(
          [
              detections_module.Detection.create_from_pb2(result)
              for result in detection_proto_list
          ]
      )

      timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
      options.result_callback(
          detection_result,
          image,
          timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
      )

    task_info = _TaskInfo(
        task_graph=_TASK_GRAPH_NAME,
        input_streams=[
            ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
            ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
        ],
        output_streams=[
            ':'.join([_DETECTIONS_TAG, _DETECTIONS_OUT_STREAM_NAME]),
            ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
        ],
        task_options=options,
    )
    return cls(
        task_info.generate_graph_config(
            enable_flow_limiting=options.running_mode
            == _RunningMode.LIVE_STREAM
        ),
        options.running_mode,
        packets_callback if options.result_callback else None,
    )
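
For the live stream path, a hedged sketch of a `result_callback` whose
signature matches what `packets_callback` delivers (result, input image,
timestamp in milliseconds); the model path is again a placeholder:

def on_result(
    result: FaceDetectorResult, image: mp.Image, timestamp_ms: int
) -> None:
  # Invoked on MediaPipe's callback thread; keep the work lightweight.
  print(f'{len(result.detections)} face(s) at {timestamp_ms} ms')

options = FaceDetectorOptions(
    base_options=_BaseOptions(model_asset_path='face_detector.tflite'),
    running_mode=_RunningMode.LIVE_STREAM,
    result_callback=on_result,
)
detector = FaceDetector.create_from_options(options)
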
  def detect(
      self,
      image: image_module.Image,
      image_processing_options: Optional[_ImageProcessingOptions] = None,
  ) -> FaceDetectorResult:
    """Performs face detection on the provided MediaPipe Image.

    Only use this method when the FaceDetector is created with the image
    running mode.

    Args:
      image: MediaPipe Image.
      image_processing_options: Options for image processing.

    Returns:
      A face detection result object that contains a list of face detections.
      Each detection has a bounding box that is expressed in the unrotated
      input frame of reference coordinates system, i.e. in `[0,image_width) x
      [0,image_height)`, which are the dimensions of the underlying image data.

    Raises:
      ValueError: If any of the input arguments is invalid.
      RuntimeError: If face detection failed to run.
    """
    normalized_rect = self.convert_to_normalized_rect(
        image_processing_options, roi_allowed=False
    )
    output_packets = self._process_image_data({
        _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
        _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
            normalized_rect.to_pb2()
        ),
    })
    if output_packets[_DETECTIONS_OUT_STREAM_NAME].is_empty():
      return FaceDetectorResult([])
    detection_proto_list = packet_getter.get_proto_list(
        output_packets[_DETECTIONS_OUT_STREAM_NAME]
    )
    return detections_module.DetectionResult(
        [
            detections_module.Detection.create_from_pb2(result)
            for result in detection_proto_list
        ]
    )
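
A usage sketch for image mode; the bounding boxes come back in pixel
coordinates of the unrotated input, so they can be drawn directly on the
source image ('portrait.jpg' is a placeholder):

image = mp.Image.create_from_file('portrait.jpg')
result = detector.detect(image)
for detection in result.detections:
  box = detection.bounding_box  # origin_x/origin_y/width/height, in pixels
  print(box.origin_x, box.origin_y, box.width, box.height)
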
  def detect_for_video(
      self,
      image: image_module.Image,
      timestamp_ms: int,
      image_processing_options: Optional[_ImageProcessingOptions] = None,
  ) -> detections_module.DetectionResult:
    """Performs face detection on the provided video frame.

    Only use this method when the FaceDetector is created with the video
    running mode. It's required to provide the video frame's timestamp (in
    milliseconds) along with the video frame. The input timestamps should be
    monotonically increasing for adjacent calls of this method.

    Args:
      image: MediaPipe Image.
      timestamp_ms: The timestamp of the input video frame in milliseconds.
      image_processing_options: Options for image processing.

    Returns:
      A face detection result object that contains a list of face detections.
      Each detection has a bounding box that is expressed in the unrotated
      input frame of reference coordinates system, i.e. in `[0,image_width) x
      [0,image_height)`, which are the dimensions of the underlying image data.

    Raises:
      ValueError: If any of the input arguments is invalid.
      RuntimeError: If face detection failed to run.
    """
    normalized_rect = self.convert_to_normalized_rect(
        image_processing_options, roi_allowed=False
    )
    output_packets = self._process_video_data({
        _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
            timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
        ),
        _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
            normalized_rect.to_pb2()
        ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
    })
    if output_packets[_DETECTIONS_OUT_STREAM_NAME].is_empty():
      return FaceDetectorResult([])
    detection_proto_list = packet_getter.get_proto_list(
        output_packets[_DETECTIONS_OUT_STREAM_NAME]
    )
    return detections_module.DetectionResult(
        [
            detections_module.Detection.create_from_pb2(result)
            for result in detection_proto_list
        ]
    )
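
A sketch of a video loop under the assumption that frames come from OpenCV
(`cv2` is not used by this module itself); deriving timestamps from the frame
index keeps them monotonically increasing, as the method requires:

import cv2

cap = cv2.VideoCapture('input.mp4')  # placeholder file name
fps = cap.get(cv2.CAP_PROP_FPS) or 30  # fall back if FPS is unavailable
frame_index = 0
while cap.isOpened():
  ok, frame = cap.read()
  if not ok:
    break
  mp_frame = mp.Image(
      image_format=mp.ImageFormat.SRGB,
      data=cv2.cvtColor(frame, cv2.COLOR_BGR2RGB),
  )
  result = detector.detect_for_video(mp_frame, int(frame_index * 1000 / fps))
  frame_index += 1
cap.release()
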
  def detect_async(
      self,
      image: image_module.Image,
      timestamp_ms: int,
      image_processing_options: Optional[_ImageProcessingOptions] = None,
  ) -> None:
    """Sends live image data (an Image with a unique timestamp) to perform face detection.

    Only use this method when the FaceDetector is created with the live stream
    running mode. The input timestamps should be monotonically increasing for
    adjacent calls of this method. This method will return immediately after
    the input image is accepted. The results will be available via the
    `result_callback` provided in the `FaceDetectorOptions`. The
    `detect_async` method is designed to process live stream data such as
    camera input. To lower the overall latency, the face detector may drop
    input images if needed. In other words, an output is not guaranteed for
    every input image.

    The `result_callback` provides:
      - A face detection result object that contains a list of face detections,
        where each detection has a bounding box that is expressed in the
        unrotated input frame of reference coordinates system,
        i.e. in `[0,image_width) x [0,image_height)`, which are the dimensions
        of the underlying image data.
      - The input image that the face detector runs on.
      - The input timestamp in milliseconds.

    Args:
      image: MediaPipe Image.
      timestamp_ms: The timestamp of the input image in milliseconds.
      image_processing_options: Options for image processing.

    Raises:
      ValueError: If the current input timestamp is smaller than what the face
        detector has already processed.
    """
    normalized_rect = self.convert_to_normalized_rect(
        image_processing_options, roi_allowed=False
    )
    self._send_live_stream_data({
        _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
            timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
        ),
        _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
            normalized_rect.to_pb2()
        ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
    })
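
A live stream sketch, reusing the live-stream `options` from the earlier
callback example; `grab_camera_frame` is a hypothetical helper returning an
`mp.Image`, and any monotonically increasing millisecond clock works for the
timestamps:

import time

start = time.monotonic()
with FaceDetector.create_from_options(options) as detector:
  while True:
    mp_frame = grab_camera_frame()  # hypothetical capture helper
    if mp_frame is None:
      break
    detector.detect_async(mp_frame, int((time.monotonic() - start) * 1000))
    # Detections arrive asynchronously via options.result_callback.
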
10
mediapipe/tasks/testdata/vision/BUILD
vendored
@@ -31,6 +31,8 @@ mediapipe_files(srcs = [
    "cat_rotated.jpg",
    "cat_rotated_mask.jpg",
    "cats_and_dogs.jpg",
    "cats_and_dogs_mask_dog1.png",
    "cats_and_dogs_mask_dog2.png",
    "cats_and_dogs_no_resizing.jpg",
    "cats_and_dogs_rotated.jpg",
    "coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite",

@@ -70,6 +72,9 @@ mediapipe_files(srcs = [
    "portrait.jpg",
    "portrait_hair_expected_mask.jpg",
    "portrait_rotated.jpg",
    "portrait_selfie_segmentation_expected_category_mask.jpg",
    "portrait_selfie_segmentation_expected_confidence_mask.jpg",
    "portrait_selfie_segmentation_landscape_expected_category_mask.jpg",
    "pose.jpg",
    "pose_detection.tflite",
    "right_hands.jpg",

@@ -113,6 +118,8 @@ filegroup(
    "cat_rotated.jpg",
    "cat_rotated_mask.jpg",
    "cats_and_dogs.jpg",
    "cats_and_dogs_mask_dog1.png",
    "cats_and_dogs_mask_dog2.png",
    "cats_and_dogs_no_resizing.jpg",
    "cats_and_dogs_rotated.jpg",
    "fist.jpg",

@@ -129,6 +136,9 @@ filegroup(
    "portrait.jpg",
    "portrait_hair_expected_mask.jpg",
    "portrait_rotated.jpg",
    "portrait_selfie_segmentation_expected_category_mask.jpg",
    "portrait_selfie_segmentation_expected_confidence_mask.jpg",
    "portrait_selfie_segmentation_landscape_expected_category_mask.jpg",
    "pose.jpg",
    "right_hands.jpg",
    "right_hands_rotated.jpg",
6
third_party/BUILD
vendored
@@ -169,11 +169,7 @@ cmake_external(
        "-lm",
        "-lpthread",
        "-lrt",
    ] + select({
        "//mediapipe:ios": ["-framework Cocoa"],
        "//mediapipe:macos": ["-framework Cocoa"],
        "//conditions:default": [],
    }),
    ],
    shared_libraries = select({
        "@bazel_tools//src/conditions:darwin": ["libopencv_%s.%s.dylib" % (module, OPENCV_SO_VERSION) for module in OPENCV_MODULES],
        # Only the shared objects listed here will be linked in the directory
46
third_party/external_files.bzl
vendored
@@ -67,13 +67,7 @@ def external_files():
    http_file(
        name = "com_google_mediapipe_BUILD",
        sha256 = "d2b2a8346202691d7f831887c84e9642e974f64ed67851d9a58cf15c94b1f6b3",
        urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=1661875663693976"],
        urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=1678323576393653"],
    )

    http_file(
        name = "com_google_mediapipe_BUILD_orig",
        sha256 = "d86b98b82e00dd87cd46bd1429bf5eaa007b500c1a24d9316b73309f2e6c8df8",
        urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD.orig?generation=1678737479599640"],
    )

    http_file(
@@ -136,6 +130,18 @@ def external_files():
        urls = ["https://storage.googleapis.com/mediapipe-assets/cats_and_dogs.jpg?generation=1661875684064150"],
    )

    http_file(
        name = "com_google_mediapipe_cats_and_dogs_mask_dog1_png",
        sha256 = "2ab37d56ba1e46e70b3ddbfe35dac51b18b597b76904c68d7d34c7c74c677d4c",
        urls = ["https://storage.googleapis.com/mediapipe-assets/cats_and_dogs_mask_dog1.png?generation=1678840350058498"],
    )

    http_file(
        name = "com_google_mediapipe_cats_and_dogs_mask_dog2_png",
        sha256 = "2010850e2dd7f520fe53b9086d70913b6fb53b178cae15a373e5ee7ffb46824a",
        urls = ["https://storage.googleapis.com/mediapipe-assets/cats_and_dogs_mask_dog2.png?generation=1678840352961684"],
    )

    http_file(
        name = "com_google_mediapipe_cats_and_dogs_no_resizing_jpg",
        sha256 = "9d55933ed66bcdc63cd6509ee2518d7eed75d12db609238387ee4cc50b173e58",
@@ -886,6 +892,24 @@ def external_files():
        urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_rotated.jpg?generation=1677194680138164"],
    )

    http_file(
        name = "com_google_mediapipe_portrait_selfie_segmentation_expected_category_mask_jpg",
        sha256 = "d8f20fa746e14067f668dd293f21bbc50ec81196d186386a6ded1278c3ec8f46",
        urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_selfie_segmentation_expected_category_mask.jpg?generation=1678606935088873"],
    )

    http_file(
        name = "com_google_mediapipe_portrait_selfie_segmentation_expected_confidence_mask_jpg",
        sha256 = "25b723e90608edaf6ed92f382da703dc904a59c87525b6d271e60d9eed7a90e9",
        urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_selfie_segmentation_expected_confidence_mask.jpg?generation=1678606937358235"],
    )

    http_file(
        name = "com_google_mediapipe_portrait_selfie_segmentation_landscape_expected_category_mask_jpg",
        sha256 = "f5c3fa3d93f8e7289b69b8a89c2519276dfa5014dcc50ed6e86e8cd4d4ae7f27",
        urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_selfie_segmentation_landscape_expected_category_mask.jpg?generation=1678606939469429"],
    )

    http_file(
        name = "com_google_mediapipe_pose_detection_tflite",
        sha256 = "9ba9dd3d42efaaba86b4ff0122b06f29c4122e756b329d89dca1e297fd8f866c",
@@ -1014,8 +1038,8 @@ def external_files():

    http_file(
        name = "com_google_mediapipe_selfie_segm_128_128_3_expected_mask_jpg",
        sha256 = "a295f3ab394a5e0caff2db5041337da58341ec331f1413ef91f56e0d650b4a1e",
        urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segm_128_128_3_expected_mask.jpg?generation=1661875916766416"],
        sha256 = "1a2a068287d8bcd4184492485b3dbb95a09b763f4653fd729d14a836147eb383",
        urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segm_128_128_3_expected_mask.jpg?generation=1678606942616777"],
    )

    http_file(

@@ -1026,8 +1050,8 @@ def external_files():

    http_file(
        name = "com_google_mediapipe_selfie_segm_144_256_3_expected_mask_jpg",
        sha256 = "cfc699db9670585c04414d0d1a07b289a027ba99d6903d2219f897d34e2c9952",
        urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segm_144_256_3_expected_mask.jpg?generation=1661875922646736"],
        sha256 = "2de433b6e8adabec2aaf80135232db900903ead4f2811c0c9378a6792b2a68b5",
        urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segm_144_256_3_expected_mask.jpg?generation=1678606945085676"],
    )

    http_file(