Merge branch 'google:master' into face-landmarker-python
commit 647db21fc3
@@ -499,8 +499,8 @@ cc_crosstool(name = "crosstool")
# Node dependencies
http_archive(
    name = "build_bazel_rules_nodejs",
    sha256 = "5aae76dced38f784b58d9776e4ab12278bc156a9ed2b1d9fcd3e39921dc88fda",
    urls = ["https://github.com/bazelbuild/rules_nodejs/releases/download/5.7.1/rules_nodejs-5.7.1.tar.gz"],
    sha256 = "94070eff79305be05b7699207fbac5d2608054dd53e6109f7d00d923919ff45a",
    urls = ["https://github.com/bazelbuild/rules_nodejs/releases/download/5.8.2/rules_nodejs-5.8.2.tar.gz"],
)

load("@build_bazel_rules_nodejs//:repositories.bzl", "build_bazel_rules_nodejs_dependencies")
@@ -1270,6 +1270,50 @@ cc_library(
    alwayslink = 1,
)

mediapipe_proto_library(
    name = "flat_color_image_calculator_proto",
    srcs = ["flat_color_image_calculator.proto"],
    deps = [
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:calculator_proto",
        "//mediapipe/util:color_proto",
    ],
)

cc_library(
    name = "flat_color_image_calculator",
    srcs = ["flat_color_image_calculator.cc"],
    deps = [
        ":flat_color_image_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:node",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:image_frame",
        "//mediapipe/framework/formats:image_frame_opencv",
        "//mediapipe/framework/port:opencv_core",
        "//mediapipe/util:color_cc_proto",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
    ],
    alwayslink = 1,
)

cc_test(
    name = "flat_color_image_calculator_test",
    srcs = ["flat_color_image_calculator_test.cc"],
    deps = [
        ":flat_color_image_calculator",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:calculator_runner",
        "//mediapipe/framework:packet",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:image_frame",
        "//mediapipe/framework/port:gtest",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/util:color_cc_proto",
    ],
)

cc_library(
    name = "from_image_calculator",
    srcs = ["from_image_calculator.cc"],
138  mediapipe/calculators/util/flat_color_image_calculator.cc  Normal file
@@ -0,0 +1,138 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/calculators/util/flat_color_image_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/util/color.pb.h"

namespace mediapipe {

namespace {

using ::mediapipe::api2::Input;
using ::mediapipe::api2::Node;
using ::mediapipe::api2::Output;
}  // namespace

// A calculator for generating an image filled with a single color.
//
// Inputs:
//   IMAGE (Image, optional)
//     If provided, the output will have the same size
//   COLOR (Color proto, optional)
//     Color to paint the output with. Takes precedence over the equivalent
//     calculator options.
//
// Outputs:
//   IMAGE (Image)
//     Image filled with the requested color.
//
// Example usage:
// node {
//   calculator: "FlatColorImageCalculator"
//   input_stream: "IMAGE:image"
//   input_stream: "COLOR:color"
//   output_stream: "IMAGE:blank_image"
//   options {
//     [mediapipe.FlatColorImageCalculatorOptions.ext] {
//       color: {
//         r: 255
//         g: 255
//         b: 255
//       }
//     }
//   }
// }

class FlatColorImageCalculator : public Node {
 public:
  static constexpr Input<Image>::Optional kInImage{"IMAGE"};
  static constexpr Input<Color>::Optional kInColor{"COLOR"};
  static constexpr Output<Image> kOutImage{"IMAGE"};

  MEDIAPIPE_NODE_CONTRACT(kInImage, kInColor, kOutImage);

  static absl::Status UpdateContract(CalculatorContract* cc) {
    const auto& options = cc->Options<FlatColorImageCalculatorOptions>();

    RET_CHECK(kInImage(cc).IsConnected() ^
              (options.has_output_height() || options.has_output_width()))
        << "Either set IMAGE input stream, or set through options";
    RET_CHECK(kInColor(cc).IsConnected() ^ options.has_color())
        << "Either set COLOR input stream, or set through options";

    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) override;
  absl::Status Process(CalculatorContext* cc) override;

 private:
  bool use_dimension_from_option_ = false;
  bool use_color_from_option_ = false;
};
MEDIAPIPE_REGISTER_NODE(FlatColorImageCalculator);

absl::Status FlatColorImageCalculator::Open(CalculatorContext* cc) {
  use_dimension_from_option_ = !kInImage(cc).IsConnected();
  use_color_from_option_ = !kInColor(cc).IsConnected();
  return absl::OkStatus();
}

absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) {
  const auto& options = cc->Options<FlatColorImageCalculatorOptions>();

  int output_height = -1;
  int output_width = -1;
  if (use_dimension_from_option_) {
    output_height = options.output_height();
    output_width = options.output_width();
  } else if (!kInImage(cc).IsEmpty()) {
    const Image& input_image = kInImage(cc).Get();
    output_height = input_image.height();
    output_width = input_image.width();
  } else {
    return absl::OkStatus();
  }

  Color color;
  if (use_color_from_option_) {
    color = options.color();
  } else if (!kInColor(cc).IsEmpty()) {
    color = kInColor(cc).Get();
  } else {
    return absl::OkStatus();
  }

  auto output_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
                                                   output_width, output_height);
  cv::Mat output_mat = mediapipe::formats::MatView(output_frame.get());

  output_mat.setTo(cv::Scalar(color.r(), color.g(), color.b()));

  kOutImage(cc).Send(Image(output_frame));

  return absl::OkStatus();
}

}  // namespace mediapipe
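The calculator's contract (exactly one of the IMAGE stream or the output dimensions, and exactly one of the COLOR stream or the color option) can be exercised from a standalone graph. The following is a minimal sketch, not part of the commit; it assumes the standard CalculatorGraph runner APIs already used elsewhere in MediaPipe and mirrors the SpecifyDimensionThroughOptions test further down, with illustrative stream names.

// Sketch only; stream names and dimensions are illustrative.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/util/color.pb.h"

absl::Status RunFlatColorGraph() {
  auto config = mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
    input_stream: "color"
    output_stream: "out_image"
    node {
      calculator: "FlatColorImageCalculator"
      input_stream: "COLOR:color"
      output_stream: "IMAGE:out_image"
      options {
        [mediapipe.FlatColorImageCalculatorOptions.ext] {
          output_width: 64
          output_height: 64
        }
      }
    }
  )pb");

  mediapipe::CalculatorGraph graph;
  MP_RETURN_IF_ERROR(graph.Initialize(config));
  ASSIGN_OR_RETURN(mediapipe::OutputStreamPoller poller,
                   graph.AddOutputStreamPoller("out_image"));
  MP_RETURN_IF_ERROR(graph.StartRun({}));

  mediapipe::Color color;
  color.set_r(100);
  color.set_g(200);
  color.set_b(255);
  MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
      "color", mediapipe::MakePacket<mediapipe::Color>(color).At(
                   mediapipe::Timestamp(0))));
  MP_RETURN_IF_ERROR(graph.CloseInputStream("color"));

  mediapipe::Packet packet;
  if (poller.Next(&packet)) {
    // A 64x64 SRGB image filled with (100, 200, 255).
    const auto& image = packet.Get<mediapipe::Image>();
    (void)image;
  }
  return graph.WaitUntilDone();
}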
32  mediapipe/calculators/util/flat_color_image_calculator.proto  Normal file
@@ -0,0 +1,32 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto2";

package mediapipe;

import "mediapipe/framework/calculator.proto";
import "mediapipe/util/color.proto";

message FlatColorImageCalculatorOptions {
  extend CalculatorOptions {
    optional FlatColorImageCalculatorOptions ext = 515548435;
  }

  // Output dimensions.
  optional int32 output_width = 1;
  optional int32 output_height = 2;
  // The color to fill with in the output image.
  optional Color color = 3;
}
210  mediapipe/calculators/util/flat_color_image_calculator_test.cc  Normal file
@@ -0,0 +1,210 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/util/color.pb.h"

namespace mediapipe {
namespace {

using ::testing::HasSubstr;

constexpr char kImageTag[] = "IMAGE";
constexpr char kColorTag[] = "COLOR";
constexpr int kImageWidth = 256;
constexpr int kImageHeight = 256;

TEST(FlatColorImageCalculatorTest, SpecifyColorThroughOptions) {
  CalculatorRunner runner(R"pb(
    calculator: "FlatColorImageCalculator"
    input_stream: "IMAGE:image"
    output_stream: "IMAGE:out_image"
    options {
      [mediapipe.FlatColorImageCalculatorOptions.ext] {
        color: {
          r: 100,
          g: 200,
          b: 255,
        }
      }
    }
  )pb");

  auto image_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
                                                  kImageWidth, kImageHeight);

  for (int ts = 0; ts < 3; ++ts) {
    runner.MutableInputs()->Tag(kImageTag).packets.push_back(
        MakePacket<Image>(image_frame).At(Timestamp(ts)));
  }
  MP_ASSERT_OK(runner.Run());

  const auto& outputs = runner.Outputs().Tag(kImageTag).packets;
  ASSERT_EQ(outputs.size(), 3);

  for (const auto& packet : outputs) {
    const auto& image = packet.Get<Image>();
    EXPECT_EQ(image.width(), kImageWidth);
    EXPECT_EQ(image.height(), kImageHeight);
    auto image_frame = image.GetImageFrameSharedPtr();
    auto* pixel_data = image_frame->PixelData();
    EXPECT_EQ(pixel_data[0], 100);
    EXPECT_EQ(pixel_data[1], 200);
    EXPECT_EQ(pixel_data[2], 255);
  }
}

TEST(FlatColorImageCalculatorTest, SpecifyDimensionThroughOptions) {
  CalculatorRunner runner(R"pb(
    calculator: "FlatColorImageCalculator"
    input_stream: "COLOR:color"
    output_stream: "IMAGE:out_image"
    options {
      [mediapipe.FlatColorImageCalculatorOptions.ext] {
        output_width: 7,
        output_height: 13,
      }
    }
  )pb");

  Color color;
  color.set_r(0);
  color.set_g(5);
  color.set_b(0);

  for (int ts = 0; ts < 3; ++ts) {
    runner.MutableInputs()->Tag(kColorTag).packets.push_back(
        MakePacket<Color>(color).At(Timestamp(ts)));
  }
  MP_ASSERT_OK(runner.Run());

  const auto& outputs = runner.Outputs().Tag(kImageTag).packets;
  ASSERT_EQ(outputs.size(), 3);

  for (const auto& packet : outputs) {
    const auto& image = packet.Get<Image>();
    EXPECT_EQ(image.width(), 7);
    EXPECT_EQ(image.height(), 13);
    auto image_frame = image.GetImageFrameSharedPtr();
    const uint8_t* pixel_data = image_frame->PixelData();
    EXPECT_EQ(pixel_data[0], 0);
    EXPECT_EQ(pixel_data[1], 5);
    EXPECT_EQ(pixel_data[2], 0);
  }
}

TEST(FlatColorImageCalculatorTest, FailureMissingDimension) {
  CalculatorRunner runner(R"pb(
    calculator: "FlatColorImageCalculator"
    input_stream: "COLOR:color"
    output_stream: "IMAGE:out_image"
  )pb");

  Color color;
  color.set_r(0);
  color.set_g(5);
  color.set_b(0);

  for (int ts = 0; ts < 3; ++ts) {
    runner.MutableInputs()->Tag(kColorTag).packets.push_back(
        MakePacket<Color>(color).At(Timestamp(ts)));
  }
  ASSERT_THAT(runner.Run().message(),
              HasSubstr("Either set IMAGE input stream"));
}

TEST(FlatColorImageCalculatorTest, FailureMissingColor) {
  CalculatorRunner runner(R"pb(
    calculator: "FlatColorImageCalculator"
    input_stream: "IMAGE:image"
    output_stream: "IMAGE:out_image"
  )pb");

  auto image_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
                                                  kImageWidth, kImageHeight);

  for (int ts = 0; ts < 3; ++ts) {
    runner.MutableInputs()->Tag(kImageTag).packets.push_back(
        MakePacket<Image>(image_frame).At(Timestamp(ts)));
  }
  ASSERT_THAT(runner.Run().message(),
              HasSubstr("Either set COLOR input stream"));
}

TEST(FlatColorImageCalculatorTest, FailureDuplicateDimension) {
  CalculatorRunner runner(R"pb(
    calculator: "FlatColorImageCalculator"
    input_stream: "IMAGE:image"
    input_stream: "COLOR:color"
    output_stream: "IMAGE:out_image"
    options {
      [mediapipe.FlatColorImageCalculatorOptions.ext] {
        output_width: 7,
        output_height: 13,
      }
    }
  )pb");

  auto image_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
                                                  kImageWidth, kImageHeight);

  for (int ts = 0; ts < 3; ++ts) {
    runner.MutableInputs()->Tag(kImageTag).packets.push_back(
        MakePacket<Image>(image_frame).At(Timestamp(ts)));
  }
  ASSERT_THAT(runner.Run().message(),
              HasSubstr("Either set IMAGE input stream"));
}

TEST(FlatColorImageCalculatorTest, FailureDuplicateColor) {
  CalculatorRunner runner(R"pb(
    calculator: "FlatColorImageCalculator"
    input_stream: "IMAGE:image"
    input_stream: "COLOR:color"
    output_stream: "IMAGE:out_image"
    options {
      [mediapipe.FlatColorImageCalculatorOptions.ext] {
        color: {
          r: 100,
          g: 200,
          b: 255,
        }
      }
    }
  )pb");

  Color color;
  color.set_r(0);
  color.set_g(5);
  color.set_b(0);

  for (int ts = 0; ts < 3; ++ts) {
    runner.MutableInputs()->Tag(kColorTag).packets.push_back(
        MakePacket<Color>(color).At(Timestamp(ts)));
  }
  ASSERT_THAT(runner.Run().message(),
              HasSubstr("Either set COLOR input stream"));
}

}  // namespace
}  // namespace mediapipe
@@ -204,7 +204,7 @@ def rewrite_mediapipe_proto(name, rewrite_proto, source_proto, **kwargs):
        'import public "' + join_path + '";',
    )
    rewrite_ref = SubsituteCommand(
        r"mediapipe\\.(" + rewrite_message_regex + ")",
        r"mediapipe\.(" + rewrite_message_regex + ")",
        r"mediapipe.\\1",
    )
    rewrite_objc = SubsituteCommand(
@@ -467,6 +467,7 @@ cc_library(
        "//mediapipe/framework/formats:frame_buffer",
        "//mediapipe/framework/formats:image_frame",
        "//mediapipe/framework/formats:yuv_image",
        "//mediapipe/util/frame_buffer:frame_buffer_util",
        "//third_party/libyuv",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
48  mediapipe/model_maker/python/vision/face_stylizer/BUILD  Normal file
@@ -0,0 +1,48 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Placeholder for internal Python strict test compatibility macro.
# Placeholder for internal Python strict library and test compatibility macro.

licenses(["notice"])

package(default_visibility = ["//mediapipe:__subpackages__"])

filegroup(
    name = "testdata",
    srcs = glob([
        "testdata/**",
    ]),
)

py_library(
    name = "dataset",
    srcs = ["dataset.py"],
    deps = [
        "//mediapipe/model_maker/python/core/data:classification_dataset",
        "//mediapipe/model_maker/python/vision/core:image_utils",
    ],
)

py_test(
    name = "dataset_test",
    srcs = ["dataset_test.py"],
    data = [
        ":testdata",
    ],
    deps = [
        ":dataset",
        "//mediapipe/tasks/python/test:test_utils",
    ],
)
@@ -0,0 +1,14 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe Model Maker Python Public API For Face Stylization."""
98  mediapipe/model_maker/python/vision/face_stylizer/dataset.py  Normal file
@@ -0,0 +1,98 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Face stylizer dataset library."""

import logging
import os

import tensorflow as tf

from mediapipe.model_maker.python.core.data import classification_dataset
from mediapipe.model_maker.python.vision.core import image_utils


# TODO: Change to an unlabeled dataset if it makes sense.
class Dataset(classification_dataset.ClassificationDataset):
  """Dataset library for face stylizer fine tuning."""

  @classmethod
  def from_folder(
      cls, dirname: str
  ) -> classification_dataset.ClassificationDataset:
    """Loads images from the given directory.

    The style image dataset directory is expected to contain one subdirectory
    whose name represents the label of the style. There can be one or multiple
    images of the same style in that subdirectory. Supported input image formats
    include 'jpg', 'jpeg', 'png'.

    Args:
      dirname: Name of the directory containing the image files.

    Returns:
      Dataset containing images and labels and other related info.
    Raises:
      ValueError: if the input data directory is empty.
    """
    data_root = os.path.abspath(dirname)

    # Assumes the image data of the same label are in the same subdirectory,
    # gets image path and label names.
    all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*'))
    all_image_size = len(all_image_paths)
    if all_image_size == 0:
      raise ValueError('Invalid input data directory')
    if not any(
        fname.endswith(('.jpg', '.jpeg', '.png')) for fname in all_image_paths
    ):
      raise ValueError('No images found under given directory')

    label_names = sorted(
        name
        for name in os.listdir(data_root)
        if os.path.isdir(os.path.join(data_root, name))
    )
    all_label_size = len(label_names)
    index_by_label = dict(
        (name, index) for index, name in enumerate(label_names)
    )
    # Get the style label from the subdirectory name.
    all_image_labels = [
        index_by_label[os.path.basename(os.path.dirname(path))]
        for path in all_image_paths
    ]

    path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)

    image_ds = path_ds.map(
        image_utils.load_image, num_parallel_calls=tf.data.AUTOTUNE
    )

    # Load label
    label_ds = tf.data.Dataset.from_tensor_slices(
        tf.cast(all_image_labels, tf.int64)
    )

    # Create a dataset of (image, label) pairs
    image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))

    logging.info(
        'Load images dataset with size: %d, num_label: %d, labels: %s.',
        all_image_size,
        all_label_size,
        ', '.join(label_names),
    )
    return Dataset(
        dataset=image_label_ds, size=all_image_size, label_names=label_names
    )
@@ -0,0 +1,48 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf

from mediapipe.model_maker.python.vision.face_stylizer import dataset
from mediapipe.tasks.python.test import test_utils


class DatasetTest(tf.test.TestCase):

  def setUp(self):
    super().setUp()
    # TODO: Replace the stylize image dataset with licensed images.
    self._test_data_dirname = 'testdata'

  def test_from_folder(self):
    input_data_dir = test_utils.get_test_data_path(self._test_data_dirname)
    data = dataset.Dataset.from_folder(dirname=input_data_dir)
    self.assertEqual(data.num_classes, 2)
    self.assertEqual(data.label_names, ['cartoon', 'sketch'])
    self.assertLen(data, 2)

  def test_from_folder_raise_value_error_for_invalid_path(self):
    with self.assertRaisesRegex(ValueError, 'Invalid input data directory'):
      dataset.Dataset.from_folder(dirname='invalid')

  def test_from_folder_raise_value_error_for_valid_no_data_path(self):
    input_data_dir = test_utils.get_test_data_path('face_stylizer')
    with self.assertRaisesRegex(
        ValueError, 'No images found under given directory'
    ):
      dataset.Dataset.from_folder(dirname=input_data_dir)


if __name__ == '__main__':
  tf.test.main()
BIN  mediapipe/model_maker/python/vision/face_stylizer/testdata/cartoon/disney.png  (vendored, new binary file, not shown; 347 KiB)
BIN  mediapipe/model_maker/python/vision/face_stylizer/testdata/sketch/sketch.png  (vendored, new binary file, not shown; 336 KiB)
@@ -57,6 +57,7 @@ pybind_extension(
        "//mediapipe/framework/formats:landmark_registration",
        "//mediapipe/framework/formats:rect_registration",
        "//mediapipe/modules/objectron/calculators:annotation_registration",
        "//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_registration",
    ],
)
@@ -30,7 +30,7 @@ constexpr absl::string_view kMediaPipeTasksPayload = "MediaPipeTasksStatus";
//
// At runtime, such codes are meant to be attached (where applicable) to a
// `absl::Status` in a key-value manner with `kMediaPipeTasksPayload` as key and
// stringifed error code as value (aka payload). This logic is encapsulated in
// stringified error code as value (aka payload). This logic is encapsulated in
// the `CreateStatusWithPayload` helper below for convenience.
//
// The returned status includes:
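A small sketch of the payload mechanism described in this comment follows. It is not part of the commit; the helper is called exactly this way in model_asset_bundle_resources.cc further down, and the error text here is illustrative.

// Sketch only; mirrors CreateStatusWithPayload call sites later in this diff.
absl::Status MakeNotFoundStatus(absl::string_view filename) {
  return CreateStatusWithPayload(
      absl::StatusCode::kNotFound,
      absl::StrFormat("No file with name: %s.", filename),
      MediaPipeTasksStatus::kFileNotFoundError);
}

// Reading the payload back, as model_asset_bundle_resources_test.cc does:
//   EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
//               testing::Optional(absl::Cord(
//                   absl::StrCat(MediaPipeTasksStatus::kFileNotFoundError))));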
@@ -51,12 +51,11 @@ ModelAssetBundleResources::Create(
  auto model_bundle_resources = absl::WrapUnique(
      new ModelAssetBundleResources(tag, std::move(model_asset_bundle_file)));
  MP_RETURN_IF_ERROR(
      model_bundle_resources->ExtractModelFilesFromExternalFileProto());
      model_bundle_resources->ExtractFilesFromExternalFileProto());
  return model_bundle_resources;
}

absl::Status
ModelAssetBundleResources::ExtractModelFilesFromExternalFileProto() {
absl::Status ModelAssetBundleResources::ExtractFilesFromExternalFileProto() {
  if (model_asset_bundle_file_->has_file_name()) {
    // If the model asset bundle file name is a relative path, searches the file
    // in a platform-specific location and returns the absolute path on success.

@@ -72,34 +71,32 @@ ModelAssetBundleResources::ExtractModelFilesFromExternalFileProto() {
      model_asset_bundle_file_handler_->GetFileContent().data();
  size_t buffer_size =
      model_asset_bundle_file_handler_->GetFileContent().size();
  return metadata::ExtractFilesfromZipFile(buffer_data, buffer_size,
                                           &model_files_);
  return metadata::ExtractFilesfromZipFile(buffer_data, buffer_size, &files_);
}

absl::StatusOr<absl::string_view> ModelAssetBundleResources::GetModelFile(
absl::StatusOr<absl::string_view> ModelAssetBundleResources::GetFile(
    const std::string& filename) const {
  auto it = model_files_.find(filename);
  if (it == model_files_.end()) {
    auto model_files = ListModelFiles();
    std::string all_model_files =
        absl::StrJoin(model_files.begin(), model_files.end(), ", ");
  auto it = files_.find(filename);
  if (it == files_.end()) {
    auto files = ListFiles();
    std::string all_files = absl::StrJoin(files.begin(), files.end(), ", ");

    return CreateStatusWithPayload(
        StatusCode::kNotFound,
        absl::StrFormat("No model file with name: %s. All model files in the "
                        "model asset bundle are: %s.",
                        filename, all_model_files),
        absl::StrFormat("No file with name: %s. All files in the model asset "
                        "bundle are: %s.",
                        filename, all_files),
        MediaPipeTasksStatus::kFileNotFoundError);
  }
  return it->second;
}

std::vector<std::string> ModelAssetBundleResources::ListModelFiles() const {
  std::vector<std::string> model_names;
  for (const auto& [model_name, _] : model_files_) {
    model_names.push_back(model_name);
std::vector<std::string> ModelAssetBundleResources::ListFiles() const {
  std::vector<std::string> file_names;
  for (const auto& [file_name, _] : files_) {
    file_names.push_back(file_name);
  }
  return model_names;
  return file_names;
}

}  // namespace core
@@ -28,8 +28,8 @@ namespace core {
// The mediapipe task model asset bundle resources class.
// A ModelAssetBundleResources object, created from an external file proto,
// contains model asset bundle related resources and the method to extract the
// tflite models or model asset bundles for the mediapipe sub-tasks. As the
// resources are owned by the ModelAssetBundleResources object
// tflite models, resource files or model asset bundles for the mediapipe
// sub-tasks. As the resources are owned by the ModelAssetBundleResources object
// callers must keep ModelAssetBundleResources alive while using any of the
// resources.
class ModelAssetBundleResources {
@@ -50,14 +50,13 @@ class ModelAssetBundleResources {
  // Returns the model asset bundle resources tag.
  std::string GetTag() const { return tag_; }

  // Gets the contents of the model file (either tflite model file or model
  // bundle file) with the provided name. An error is returned if there is no
  // such model file.
  absl::StatusOr<absl::string_view> GetModelFile(
      const std::string& filename) const;
  // Gets the contents of the model file (either tflite model file, resource
  // file or model bundle file) with the provided name. An error is returned if
  // there is no such model file.
  absl::StatusOr<absl::string_view> GetFile(const std::string& filename) const;

  // Lists all the model file names in the model asset model.
  std::vector<std::string> ListModelFiles() const;
  // Lists all the file names in the model asset model.
  std::vector<std::string> ListFiles() const;

 private:
  // Constructor.
@@ -65,9 +64,9 @@ class ModelAssetBundleResources {
      const std::string& tag,
      std::unique_ptr<proto::ExternalFile> model_asset_bundle_file);

  // Extracts the model files (either tflite model file or model bundle file)
  // from the external file proto.
  absl::Status ExtractModelFilesFromExternalFileProto();
  // Extracts the model files (either tflite model file, resource file or model
  // bundle file) from the external file proto.
  absl::Status ExtractFilesFromExternalFileProto();

  // The model asset bundle resources tag.
  const std::string tag_;
@@ -78,11 +77,11 @@ class ModelAssetBundleResources {
  // The ExternalFileHandler for the model asset bundle.
  std::unique_ptr<ExternalFileHandler> model_asset_bundle_file_handler_;

  // The model files bundled in model asset bundle, as a map with the filename
  // The files bundled in model asset bundle, as a map with the filename
  // (corresponding to a basename, e.g. "hand_detector.tflite") as key and
  // a pointer to the file contents as value. Each model file can be either
  // a TFLite model file or a model bundle file for sub-task.
  absl::flat_hash_map<std::string, absl::string_view> model_files_;
  // a pointer to the file contents as value. Each file can be either a TFLite
  // model file, resource file or a model bundle file for sub-task.
  absl::flat_hash_map<std::string, absl::string_view> files_;
};

}  // namespace core
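Taken together, the renamed API reads as in the sketch below. This is not code from the commit: Create, ListFiles and GetFile match the declarations above, but the bundle path and tag are placeholders.

// Sketch only; the bundle file name and tag are hypothetical.
absl::Status DumpBundleContents() {
  auto external_file = std::make_unique<proto::ExternalFile>();
  external_file->set_file_name("example_bundle.task");  // placeholder path

  ASSIGN_OR_RETURN(auto resources,
                   ModelAssetBundleResources::Create(
                       /*tag=*/"example_tag", std::move(external_file)));

  // Any packed file (a .tflite model, a nested .task bundle, or a plain
  // resource file) is listed and fetched by its basename.
  for (const std::string& name : resources->ListFiles()) {
    ASSIGN_OR_RETURN(absl::string_view contents, resources->GetFile(name));
    // `contents` is only valid while `resources` stays alive.
    (void)contents;
  }
  return absl::OkStatus();
}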
@@ -66,10 +66,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromBinaryContent) {
      ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
                                        std::move(model_file)));
  MP_EXPECT_OK(
      model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
          .status());
      model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
  MP_EXPECT_OK(
      model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
      model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
          .status());
}

@@ -81,10 +80,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFile) {
      ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
                                        std::move(model_file)));
  MP_EXPECT_OK(
      model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
          .status());
      model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
  MP_EXPECT_OK(
      model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
      model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
          .status());
}

@@ -98,10 +96,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFileDescriptor) {
      ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
                                        std::move(model_file)));
  MP_EXPECT_OK(
      model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
          .status());
      model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
  MP_EXPECT_OK(
      model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
      model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
          .status());
}
#endif  // _WIN32

@@ -115,10 +112,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFilePointer) {
      ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
                                        std::move(model_file)));
  MP_EXPECT_OK(
      model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
          .status());
      model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
  MP_EXPECT_OK(
      model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
      model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
          .status());
}

@@ -147,7 +143,7 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidModelBundleFile) {
      ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
                                        std::move(model_file)));
  auto status_or_model_bundle_file =
      model_bundle_resources->GetModelFile("dummy_hand_landmarker.task");
      model_bundle_resources->GetFile("dummy_hand_landmarker.task");
  MP_EXPECT_OK(status_or_model_bundle_file.status());

  // Creates sub-task model asset bundle resources.

@@ -159,10 +155,10 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidModelBundleFile) {
      ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
                                        std::move(hand_landmaker_model_file)));
  MP_EXPECT_OK(hand_landmaker_model_bundle_resources
                   ->GetModelFile("dummy_hand_detector.tflite")
                   ->GetFile("dummy_hand_detector.tflite")
                   .status());
  MP_EXPECT_OK(hand_landmaker_model_bundle_resources
                   ->GetModelFile("dummy_hand_landmarker.tflite")
                   ->GetFile("dummy_hand_landmarker.tflite")
                   .status());
}

@@ -175,7 +171,7 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidTFLiteModelFile) {
      ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
                                        std::move(model_file)));
  auto status_or_model_bundle_file =
      model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite");
      model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite");
  MP_EXPECT_OK(status_or_model_bundle_file.status());

  // Verify tflite model works.

@@ -200,12 +196,12 @@ TEST(ModelAssetBundleResourcesTest, ExtractInvalidModelFile) {
      auto model_bundle_resources,
      ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
                                        std::move(model_file)));
  auto status = model_bundle_resources->GetModelFile("not_found.task").status();
  auto status = model_bundle_resources->GetFile("not_found.task").status();
  EXPECT_EQ(status.code(), absl::StatusCode::kNotFound);
  EXPECT_THAT(status.message(),
              testing::HasSubstr(
                  "No model file with name: not_found.task. All model files in "
                  "the model asset bundle are: "));
  EXPECT_THAT(
      status.message(),
      testing::HasSubstr("No file with name: not_found.task. All files in "
                         "the model asset bundle are: "));
  EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
              testing::Optional(absl::Cord(
                  absl::StrCat(MediaPipeTasksStatus::kFileNotFoundError))));

@@ -219,7 +215,7 @@ TEST(ModelAssetBundleResourcesTest, ListModelFiles) {
      auto model_bundle_resources,
      ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
                                        std::move(model_file)));
  auto model_files = model_bundle_resources->ListModelFiles();
  auto model_files = model_bundle_resources->ListFiles();
  std::vector<std::string> expected_model_files = {
      "dummy_gesture_recognizer.tflite", "dummy_hand_landmarker.task"};
  std::sort(model_files.begin(), model_files.end());
@@ -64,7 +64,7 @@ class ModelMetadataPopulator {
  // Loads associated files into the TFLite FlatBuffer model. The input is a map
  // of {filename, file contents}.
  //
  // Warning: this method removes any previoulsy present associated files.
  // Warning: this method removes any previously present associated files.
  // Calling this method multiple times removes any associated files from
  // previous calls, so this method should usually be called only once.
  void LoadAssociatedFiles(
@@ -31,8 +31,8 @@ PYBIND11_MODULE(_pywrap_metadata_version, m) {

  // Using pybind11 type conversions to convert between Python and native
  // C++ types. There are other options to provide access to native Python types
  // in C++ and vice versa. See the pybind 11 instrcution [1] for more details.
  // Type converstions is recommended by pybind11, though the main downside
  // in C++ and vice versa. See the pybind 11 instruction [1] for more details.
  // Type conversions is recommended by pybind11, though the main downside
  // is that a copy of the data must be made on every Python to C++ transition:
  // this is needed since the C++ and Python versions of the same type generally
  // won’t have the same memory layout.
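As a concrete illustration of the type-conversion trade-off described in that comment, a minimal pybind11 module could look like the sketch below. The module and function names are made up and are not part of MediaPipe.

// Sketch only: pybind11 copies data on every Python <-> C++ transition.
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>  // enables automatic std::vector/std::string conversion

#include <string>
#include <vector>

namespace py = pybind11;

PYBIND11_MODULE(conversion_demo, m) {
  // The std::vector<std::string> argument is rebuilt from the Python list by
  // copy on each call, and the returned std::string is copied back out.
  m.def("join", [](const std::vector<std::string>& parts) {
    std::string out;
    for (const auto& p : parts) out += p;
    return out;
  });
}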
@@ -79,7 +79,7 @@ TEST(MetadataVersionTest,
  auto metadata = metadata_builder.Finish();
  FinishModelMetadataBuffer(builder, metadata);

  // Gets the mimimum metadata parser version.
  // Gets the minimum metadata parser version.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),

@@ -100,7 +100,7 @@ TEST(MetadataVersionTest,
  auto metadata = metadata_builder.Finish();
  builder.Finish(metadata);

  // Gets the mimimum metadata parser version and triggers error.
  // Gets the minimum metadata parser version and triggers error.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),

@@ -121,7 +121,7 @@ TEST(MetadataVersionTest,
  metadata_builder.add_associated_files(associated_files);
  FinishModelMetadataBuffer(builder, metadata_builder.Finish());

  // Gets the mimimum metadata parser version.
  // Gets the minimum metadata parser version.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),

@@ -147,7 +147,7 @@ TEST(MetadataVersionTest,
  metadata_builder.add_subgraph_metadata(subgraphs);
  FinishModelMetadataBuffer(builder, metadata_builder.Finish());

  // Gets the mimimum metadata parser version.
  // Gets the minimum metadata parser version.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),

@@ -172,7 +172,7 @@ TEST(MetadataVersionTest,
      std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
  CreateModelWithMetadata(tensors, builder);

  // Gets the mimimum metadata parser version.
  // Gets the minimum metadata parser version.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),

@@ -203,7 +203,7 @@ TEST(MetadataVersionTest,
  metadata_builder.add_subgraph_metadata(subgraphs);
  FinishModelMetadataBuffer(builder, metadata_builder.Finish());

  // Gets the mimimum metadata parser version.
  // Gets the minimum metadata parser version.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),

@@ -234,7 +234,7 @@ TEST(MetadataVersionTest,
  metadata_builder.add_subgraph_metadata(subgraphs);
  FinishModelMetadataBuffer(builder, metadata_builder.Finish());

  // Gets the mimimum metadata parser version.
  // Gets the minimum metadata parser version.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),

@@ -294,7 +294,7 @@ TEST(MetadataVersionTest,
      std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
  CreateModelWithMetadata(tensors, builder);

  // Gets the mimimum metadata parser version.
  // Gets the minimum metadata parser version.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),

@@ -323,7 +323,7 @@ TEST(MetadataVersionTest,
      std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
  CreateModelWithMetadata(tensors, builder);

  // Gets the mimimum metadata parser version.
  // Gets the minimum metadata parser version.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),

@@ -348,7 +348,7 @@ TEST(MetadataVersionTest,
  metadata_builder.add_subgraph_metadata(subgraphs);
  FinishModelMetadataBuffer(builder, metadata_builder.Finish());

  // Gets the mimimum metadata parser version.
  // Gets the minimum metadata parser version.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),

@@ -373,7 +373,7 @@ TEST(MetadataVersionTest,
  metadata_builder.add_subgraph_metadata(subgraphs);
  FinishModelMetadataBuffer(builder, metadata_builder.Finish());

  // Gets the mimimum metadata parser version.
  // Gets the minimum metadata parser version.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),

@@ -404,7 +404,7 @@ TEST(MetadataVersionTest,
  metadata_builder.add_subgraph_metadata(subgraphs);
  FinishModelMetadataBuffer(builder, metadata_builder.Finish());

  // Gets the mimimum metadata parser version.
  // Gets the minimum metadata parser version.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),

@@ -431,7 +431,7 @@ TEST(MetadataVersionTest,
      std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
  CreateModelWithMetadata(tensors, builder);

  // Gets the mimimum metadata parser version.
  // Gets the minimum metadata parser version.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),

@@ -453,7 +453,7 @@ TEST(MetadataVersionTest,
  metadata_builder.add_associated_files(associated_files);
  FinishModelMetadataBuffer(builder, metadata_builder.Finish());

  // Gets the mimimum metadata parser version.
  // Gets the minimum metadata parser version.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),

@@ -476,7 +476,7 @@ TEST(MetadataVersionTest,
  metadata_builder.add_associated_files(associated_files);
  FinishModelMetadataBuffer(builder, metadata_builder.Finish());

  // Gets the mimimum metadata parser version.
  // Gets the minimum metadata parser version.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),

@@ -504,7 +504,7 @@ TEST(MetadataVersionTest, GetMinimumMetadataParserVersionForOptions) {
  metadata_builder.add_subgraph_metadata(subgraphs);
  FinishModelMetadataBuffer(builder, metadata_builder.Finish());

  // Gets the mimimum metadata parser version.
  // Gets the minimum metadata parser version.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),
@@ -34,7 +34,7 @@ constexpr char kTestSPModelPath[] =

std::unique_ptr<SentencePieceTokenizer> CreateSentencePieceTokenizer(
    absl::string_view model_path) {
  // We are using `LoadBinaryContent()` instead of loading the model direclty
  // We are using `LoadBinaryContent()` instead of loading the model directly
  // via `SentencePieceTokenizer` so that the file can be located on Windows
  std::string buffer = LoadBinaryContent(kTestSPModelPath);
  return absl::make_unique<SentencePieceTokenizer>(buffer.data(),
@@ -60,6 +60,7 @@ cc_library(
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/port:statusor",
        "//mediapipe/tasks/cc:common",
        "//mediapipe/tasks/cc/core:external_file_handler",
        "//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
        "//mediapipe/tasks/cc/vision/face_geometry/libs:geometry_pipeline",

@@ -69,6 +70,7 @@ cc_library(
        "//mediapipe/tasks/cc/vision/face_geometry/proto:geometry_pipeline_metadata_cc_proto",
        "//mediapipe/util:resource_util",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings:str_format",
    ],
    alwayslink = 1,
)
@ -18,12 +18,14 @@
|
|||
#include <vector>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/status_macros.h"
|
||||
#include "mediapipe/framework/port/statusor.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/core/external_file_handler.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.pb.h"
|
||||
|
@ -41,13 +43,50 @@ static constexpr char kEnvironmentTag[] = "ENVIRONMENT";
|
|||
static constexpr char kImageSizeTag[] = "IMAGE_SIZE";
|
||||
static constexpr char kMultiFaceGeometryTag[] = "MULTI_FACE_GEOMETRY";
|
||||
static constexpr char kMultiFaceLandmarksTag[] = "MULTI_FACE_LANDMARKS";
|
||||
static constexpr char kFaceGeometryTag[] = "FACE_GEOMETRY";
|
||||
static constexpr char kFaceLandmarksTag[] = "FACE_LANDMARKS";
|
||||
|
||||
using ::mediapipe::tasks::vision::face_geometry::proto::Environment;
|
||||
using ::mediapipe::tasks::vision::face_geometry::proto::FaceGeometry;
|
||||
using ::mediapipe::tasks::vision::face_geometry::proto::
|
||||
GeometryPipelineMetadata;
|
||||
|
||||
// A calculator that renders a visual effect for multiple faces.
|
||||
absl::Status SanityCheck(CalculatorContract* cc) {
|
||||
if (!(cc->Inputs().HasTag(kFaceLandmarksTag) ^
|
||||
cc->Inputs().HasTag(kMultiFaceLandmarksTag))) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
absl::StrFormat("Only one of %s and %s can be set at a time.",
|
||||
kFaceLandmarksTag, kMultiFaceLandmarksTag));
|
||||
}
|
||||
if (!(cc->Outputs().HasTag(kFaceGeometryTag) ^
|
||||
cc->Outputs().HasTag(kMultiFaceGeometryTag))) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
absl::StrFormat("Only one of %s and %s can be set at a time.",
|
||||
kFaceGeometryTag, kMultiFaceGeometryTag));
|
||||
}
|
||||
if (cc->Inputs().HasTag(kFaceLandmarksTag) !=
|
||||
cc->Outputs().HasTag(kFaceGeometryTag)) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
absl::StrFormat(
|
||||
"%s and %s must both be set or neither be set and a time.",
|
||||
kFaceLandmarksTag, kFaceGeometryTag));
|
||||
}
|
||||
if (cc->Inputs().HasTag(kMultiFaceLandmarksTag) !=
|
||||
cc->Outputs().HasTag(kMultiFaceGeometryTag)) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
absl::StrFormat(
|
||||
"%s and %s must both be set or neither be set and a time.",
|
||||
kMultiFaceLandmarksTag, kMultiFaceGeometryTag));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// A calculator that renders a visual effect for multiple faces. Support single
|
||||
// face landmarks or multiple face landmarks.
|
||||
//
|
||||
// Inputs:
|
||||
// IMAGE_SIZE (`std::pair<int, int>`, required):
|
||||
|
@ -58,8 +97,12 @@ using ::mediapipe::tasks::vision::face_geometry::proto::
|
|||
// ratio. If used as-is, the resulting face geometry visualization should be
|
||||
// happening on a frame with the same ratio as well.
|
||||
//
|
||||
// MULTI_FACE_LANDMARKS (`std::vector<NormalizedLandmarkList>`, required):
|
||||
// A vector of face landmark lists.
|
||||
// MULTI_FACE_LANDMARKS (`std::vector<NormalizedLandmarkList>`, optional):
|
||||
// A vector of face landmark lists. If connected, the output stream
|
||||
// MULTI_FACE_GEOMETRY must be connected.
|
||||
// FACE_LANDMARKS (NormalizedLandmarkList, optional):
|
||||
// A NormalizedLandmarkList of single face landmark lists. If connected, the
|
||||
// output stream FACE_GEOMETRY must be connected.
|
||||
//
|
||||
// Input side packets:
|
||||
// ENVIRONMENT (`proto::Environment`, required)
|
||||
|
@ -67,8 +110,10 @@ using ::mediapipe::tasks::vision::face_geometry::proto::
|
|||
// as well as virtual camera parameters.
|
||||
//
|
||||
// Output:
|
||||
// MULTI_FACE_GEOMETRY (`std::vector<FaceGeometry>`, required):
|
||||
// A vector of face geometry data.
|
||||
// MULTI_FACE_GEOMETRY (`std::vector<FaceGeometry>`, optional):
|
||||
// A vector of face geometry data if MULTI_FACE_LANDMARKS is connected .
|
||||
// FACE_GEOMETRY (FaceGeometry, optional):
|
||||
// A FaceGeometry of the face landmarks if FACE_LANDMARKS is connected.
|
||||
//
|
||||
// Options:
|
||||
// metadata_file (`ExternalFile`, optional):
|
||||
|
@ -81,13 +126,21 @@ class GeometryPipelineCalculator : public CalculatorBase {
|
|||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->InputSidePackets().Tag(kEnvironmentTag).Set<Environment>();
|
||||
MP_RETURN_IF_ERROR(SanityCheck(cc));
|
||||
cc->Inputs().Tag(kImageSizeTag).Set<std::pair<int, int>>();
|
||||
cc->Inputs()
|
||||
.Tag(kMultiFaceLandmarksTag)
|
||||
.Set<std::vector<mediapipe::NormalizedLandmarkList>>();
|
||||
cc->Outputs().Tag(kMultiFaceGeometryTag).Set<std::vector<FaceGeometry>>();
|
||||
|
||||
return absl::OkStatus();
|
||||
if (cc->Inputs().HasTag(kMultiFaceLandmarksTag)) {
|
||||
cc->Inputs()
|
||||
.Tag(kMultiFaceLandmarksTag)
|
||||
.Set<std::vector<mediapipe::NormalizedLandmarkList>>();
|
||||
cc->Outputs().Tag(kMultiFaceGeometryTag).Set<std::vector<FaceGeometry>>();
|
||||
return absl::OkStatus();
|
||||
} else {
|
||||
cc->Inputs()
|
||||
.Tag(kFaceLandmarksTag)
|
||||
.Set<mediapipe::NormalizedLandmarkList>();
|
||||
cc->Outputs().Tag(kFaceGeometryTag).Set<FaceGeometry>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override {
|
||||
|
@ -112,7 +165,6 @@ class GeometryPipelineCalculator : public CalculatorBase {
|
|||
ASSIGN_OR_RETURN(geometry_pipeline_,
|
||||
CreateGeometryPipeline(environment, metadata),
|
||||
_ << "Failed to create a geometry pipeline!");
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -121,32 +173,54 @@ class GeometryPipelineCalculator : public CalculatorBase {
|
|||
// to have a non-empty packet. In case this requirement is not met, there's
|
||||
// nothing to be processed at the current timestamp.
|
||||
if (cc->Inputs().Tag(kImageSizeTag).IsEmpty() ||
|
||||
cc->Inputs().Tag(kMultiFaceLandmarksTag).IsEmpty()) {
|
||||
(cc->Inputs().Tag(kMultiFaceLandmarksTag).IsEmpty() &&
|
||||
cc->Inputs().Tag(kFaceLandmarksTag).IsEmpty())) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
const auto& image_size =
|
||||
cc->Inputs().Tag(kImageSizeTag).Get<std::pair<int, int>>();
|
||||
const auto& multi_face_landmarks =
|
||||
cc->Inputs()
|
||||
.Tag(kMultiFaceLandmarksTag)
|
||||
.Get<std::vector<mediapipe::NormalizedLandmarkList>>();
|
||||
|
||||
auto multi_face_geometry = absl::make_unique<std::vector<FaceGeometry>>();
|
||||
if (cc->Inputs().HasTag(kMultiFaceLandmarksTag)) {
|
||||
const auto& multi_face_landmarks =
|
||||
cc->Inputs()
|
||||
.Tag(kMultiFaceLandmarksTag)
|
||||
.Get<std::vector<mediapipe::NormalizedLandmarkList>>();
|
||||
|
||||
ASSIGN_OR_RETURN(
|
||||
*multi_face_geometry,
|
||||
geometry_pipeline_->EstimateFaceGeometry(
|
||||
multi_face_landmarks, //
|
||||
/*frame_width*/ image_size.first,
|
||||
/*frame_height*/ image_size.second),
|
||||
_ << "Failed to estimate face geometry for multiple faces!");
|
||||
auto multi_face_geometry = absl::make_unique<std::vector<FaceGeometry>>();
|
||||
|
||||
cc->Outputs()
|
||||
.Tag(kMultiFaceGeometryTag)
|
||||
.AddPacket(mediapipe::Adopt<std::vector<FaceGeometry>>(
|
||||
multi_face_geometry.release())
|
||||
.At(cc->InputTimestamp()));
|
||||
ASSIGN_OR_RETURN(
|
||||
*multi_face_geometry,
|
||||
geometry_pipeline_->EstimateFaceGeometry(
|
||||
multi_face_landmarks, //
|
||||
/*frame_width*/ image_size.first,
|
||||
/*frame_height*/ image_size.second),
|
||||
_ << "Failed to estimate face geometry for multiple faces!");
|
||||
|
||||
cc->Outputs()
|
||||
.Tag(kMultiFaceGeometryTag)
|
||||
.AddPacket(mediapipe::Adopt<std::vector<FaceGeometry>>(
|
||||
multi_face_geometry.release())
|
||||
.At(cc->InputTimestamp()));
|
||||
} else {
|
||||
const auto& face_landmarks =
|
||||
cc->Inputs()
|
||||
.Tag(kMultiFaceLandmarksTag)
|
||||
.Get<mediapipe::NormalizedLandmarkList>();
|
||||
|
||||
ASSIGN_OR_RETURN(
|
||||
std::vector<FaceGeometry> multi_face_geometry,
|
||||
geometry_pipeline_->EstimateFaceGeometry(
|
||||
{face_landmarks}, //
|
||||
/*frame_width*/ image_size.first,
|
||||
/*frame_height*/ image_size.second),
|
||||
_ << "Failed to estimate face geometry for multiple faces!");
|
||||
|
||||
cc->Outputs()
|
||||
.Tag(kFaceGeometryTag)
|
||||
.AddPacket(mediapipe::MakePacket<FaceGeometry>(multi_face_geometry[0])
|
||||
.At(cc->InputTimestamp()));
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
|
||||
load("//mediapipe/framework:mediapipe_register_type.bzl", "mediapipe_register_type")
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
|
@ -23,6 +24,16 @@ mediapipe_proto_library(
|
|||
srcs = ["environment.proto"],
|
||||
)
|
||||
|
||||
mediapipe_register_type(
|
||||
base_name = "face_geometry",
|
||||
include_headers = ["mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h"],
|
||||
types = [
|
||||
"::mediapipe::tasks::vision::face_geometry::proto::FaceGeometry",
|
||||
"::std::vector<::mediapipe::tasks::vision::face_geometry::proto::FaceGeometry>",
|
||||
],
|
||||
deps = [":face_geometry_cc_proto"],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "face_geometry_proto",
|
||||
srcs = ["face_geometry.proto"],
|
||||
|
|
|
@ -116,7 +116,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
|
|||
options->mutable_face_detector_graph_options();
|
||||
if (!face_detector_graph_options->base_options().has_model_asset()) {
|
||||
ASSIGN_OR_RETURN(const auto face_detector_file,
|
||||
resources.GetModelFile(kFaceDetectorTFLiteName));
|
||||
resources.GetFile(kFaceDetectorTFLiteName));
|
||||
SetExternalFile(face_detector_file,
|
||||
face_detector_graph_options->mutable_base_options()
|
||||
->mutable_model_asset(),
|
||||
|
@ -132,7 +132,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
|
|||
if (!face_landmarks_detector_graph_options->base_options()
|
||||
.has_model_asset()) {
|
||||
ASSIGN_OR_RETURN(const auto face_landmarks_detector_file,
|
||||
resources.GetModelFile(kFaceLandmarksDetectorTFLiteName));
|
||||
resources.GetFile(kFaceLandmarksDetectorTFLiteName));
|
||||
SetExternalFile(
|
||||
face_landmarks_detector_file,
|
||||
face_landmarks_detector_graph_options->mutable_base_options()
|
||||
|
@ -146,7 +146,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
|
|||
->set_use_stream_mode(options->base_options().use_stream_mode());
|
||||
|
||||
absl::StatusOr<absl::string_view> face_blendshape_model =
|
||||
resources.GetModelFile(kFaceBlendshapeTFLiteName);
|
||||
resources.GetFile(kFaceBlendshapeTFLiteName);
|
||||
if (face_blendshape_model.ok()) {
|
||||
SetExternalFile(*face_blendshape_model,
|
||||
face_landmarks_detector_graph_options
|
||||
|
@ -327,7 +327,7 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph {
|
|||
// Set the face geometry metdata file for
|
||||
// FaceGeometryFromLandmarksGraph.
|
||||
ASSIGN_OR_RETURN(auto face_geometry_pipeline_metadata_file,
|
||||
model_asset_bundle_resources->GetModelFile(
|
||||
model_asset_bundle_resources->GetFile(
|
||||
kFaceGeometryPipelineMetadataName));
|
||||
SetExternalFile(face_geometry_pipeline_metadata_file,
|
||||
sc->MutableOptions<FaceLandmarkerGraphOptions>()
|
||||
|
|
|
@ -92,7 +92,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
|
|||
GestureRecognizerGraphOptions* options,
|
||||
bool is_copy) {
|
||||
ASSIGN_OR_RETURN(const auto hand_landmarker_file,
|
||||
resources.GetModelFile(kHandLandmarkerBundleAssetName));
|
||||
resources.GetFile(kHandLandmarkerBundleAssetName));
|
||||
auto* hand_landmarker_graph_options =
|
||||
options->mutable_hand_landmarker_graph_options();
|
||||
SetExternalFile(hand_landmarker_file,
|
||||
|
@ -105,9 +105,8 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
|
|||
hand_landmarker_graph_options->mutable_base_options()->set_use_stream_mode(
|
||||
options->base_options().use_stream_mode());
|
||||
|
||||
ASSIGN_OR_RETURN(
|
||||
const auto hand_gesture_recognizer_file,
|
||||
resources.GetModelFile(kHandGestureRecognizerBundleAssetName));
|
||||
ASSIGN_OR_RETURN(const auto hand_gesture_recognizer_file,
|
||||
resources.GetFile(kHandGestureRecognizerBundleAssetName));
|
||||
auto* hand_gesture_recognizer_graph_options =
|
||||
options->mutable_hand_gesture_recognizer_graph_options();
|
||||
SetExternalFile(hand_gesture_recognizer_file,
|
||||
|
@ -127,7 +126,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
|
|||
->mutable_acceleration()
|
||||
->mutable_xnnpack();
|
||||
LOG(WARNING) << "Hand Gesture Recognizer contains CPU only ops. Sets "
|
||||
<< "HandGestureRecognizerGraph acceleartion to Xnnpack.";
|
||||
<< "HandGestureRecognizerGraph acceleration to Xnnpack.";
|
||||
}
|
||||
hand_gesture_recognizer_graph_options->mutable_base_options()
|
||||
->set_use_stream_mode(options->base_options().use_stream_mode());
|
||||
|
|
|
@ -207,7 +207,7 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
|
|||
HandGestureRecognizerGraphOptions* options,
|
||||
bool is_copy) {
|
||||
ASSIGN_OR_RETURN(const auto gesture_embedder_file,
|
||||
resources.GetModelFile(kGestureEmbedderTFLiteName));
|
||||
resources.GetFile(kGestureEmbedderTFLiteName));
|
||||
auto* gesture_embedder_graph_options =
|
||||
options->mutable_gesture_embedder_graph_options();
|
||||
SetExternalFile(gesture_embedder_file,
|
||||
|
@ -218,9 +218,8 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
|
|||
options->base_options(),
|
||||
gesture_embedder_graph_options->mutable_base_options());
|
||||
|
||||
ASSIGN_OR_RETURN(
|
||||
const auto canned_gesture_classifier_file,
|
||||
resources.GetModelFile(kCannedGestureClassifierTFLiteName));
|
||||
ASSIGN_OR_RETURN(const auto canned_gesture_classifier_file,
|
||||
resources.GetFile(kCannedGestureClassifierTFLiteName));
|
||||
auto* canned_gesture_classifier_graph_options =
|
||||
options->mutable_canned_gesture_classifier_graph_options();
|
||||
SetExternalFile(
|
||||
|
@ -233,7 +232,7 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
|
|||
canned_gesture_classifier_graph_options->mutable_base_options());
|
||||
|
||||
const auto custom_gesture_classifier_file =
|
||||
resources.GetModelFile(kCustomGestureClassifierTFLiteName);
|
||||
resources.GetFile(kCustomGestureClassifierTFLiteName);
|
||||
if (custom_gesture_classifier_file.ok()) {
|
||||
has_custom_gesture_classifier = true;
|
||||
auto* custom_gesture_classifier_graph_options =
|
||||
|
|
|
@ -101,7 +101,7 @@ class HandLandmarker : tasks::vision::core::BaseVisionTaskApi {
|
|||
// three running modes:
|
||||
// 1) Image mode for detecting hand landmarks on single image inputs. Users
|
||||
// provide mediapipe::Image to the `Detect` method, and will receive the
|
||||
// deteced hand landmarks results as the return value.
|
||||
// detected hand landmarks results as the return value.
|
||||
// 2) Video mode for detecting hand landmarks on the decoded frames of a
|
||||
// video. Users call `DetectForVideo` method, and will receive the detected
|
||||
// hand landmarks results as the return value.
|
||||
|
|
|
@ -97,7 +97,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
|
|||
options->mutable_hand_detector_graph_options();
|
||||
if (!hand_detector_graph_options->base_options().has_model_asset()) {
|
||||
ASSIGN_OR_RETURN(const auto hand_detector_file,
|
||||
resources.GetModelFile(kHandDetectorTFLiteName));
|
||||
resources.GetFile(kHandDetectorTFLiteName));
|
||||
SetExternalFile(hand_detector_file,
|
||||
hand_detector_graph_options->mutable_base_options()
|
||||
->mutable_model_asset(),
|
||||
|
@ -113,7 +113,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
|
|||
if (!hand_landmarks_detector_graph_options->base_options()
|
||||
.has_model_asset()) {
|
||||
ASSIGN_OR_RETURN(const auto hand_landmarks_detector_file,
|
||||
resources.GetModelFile(kHandLandmarksDetectorTFLiteName));
|
||||
resources.GetFile(kHandLandmarksDetectorTFLiteName));
|
||||
SetExternalFile(
|
||||
hand_landmarks_detector_file,
|
||||
hand_landmarks_detector_graph_options->mutable_base_options()
|
||||
|
|
|
@ -409,7 +409,7 @@ REGISTER_MEDIAPIPE_GRAPH(
|
|||
// - Accepts CPU input image and a vector of hand rect RoIs to detect the
|
||||
// multiple hands landmarks enclosed by the RoIs. Output vectors of
|
||||
// hand landmarks related results, where each element in the vectors
|
||||
// corrresponds to the result of the same hand.
|
||||
// corresponds to the result of the same hand.
|
||||
//
|
||||
// Inputs:
|
||||
// IMAGE - Image
|
||||
|
|
|
@ -52,7 +52,7 @@ constexpr char kMobileNetV3Embedder[] =
|
|||
constexpr double kSimilarityTolerancy = 1e-6;
|
||||
|
||||
// Utility function to check the sizes, head_index and head_names of a result
|
||||
// procuded by kMobileNetV3Embedder.
|
||||
// produced by kMobileNetV3Embedder.
|
||||
void CheckMobileNetV3Result(const ImageEmbedderResult& result, bool quantized) {
|
||||
EXPECT_EQ(result.embeddings.size(), 1);
|
||||
EXPECT_EQ(result.embeddings[0].head_index, 0);
|
||||
|
|
|
@ -25,6 +25,7 @@ cc_library(
|
|||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":image_segmenter_graph",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/formats:image",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
|
@ -34,10 +35,13 @@ cc_library(
|
|||
"//mediapipe/tasks/cc/vision/core:image_processing_options",
|
||||
"//mediapipe/tasks/cc/vision/core:running_mode",
|
||||
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto",
|
||||
"//mediapipe/util:label_map_cc_proto",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -22,6 +22,9 @@ import "mediapipe/framework/calculator.proto";
|
|||
import "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto";
|
||||
import "mediapipe/util/label_map.proto";
|
||||
|
||||
option java_package = "com.google.mediapipe.tasks";
|
||||
option java_outer_classname = "TensorsToSegmentationCalculatorOptionsProto";
|
||||
|
||||
message TensorsToSegmentationCalculatorOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional TensorsToSegmentationCalculatorOptions ext = 458105876;
|
||||
|
|
|
@ -15,15 +15,21 @@ limitations under the License.
|
|||
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h"
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/utils.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h"
|
||||
#include "mediapipe/util/label_map.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
|
@ -112,6 +118,39 @@ ConvertImageSegmenterOptionsToProto(ImageSegmenterOptions* options) {
|
|||
return options_proto;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::vector<std::string>> GetLabelsFromGraphConfig(
|
||||
const CalculatorGraphConfig& graph_config) {
|
||||
bool found_tensor_to_segmentation_calculator = false;
|
||||
std::vector<std::string> labels;
|
||||
for (const auto& node : graph_config.node()) {
|
||||
if (node.calculator() ==
|
||||
"mediapipe.tasks.TensorsToSegmentationCalculator") {
|
||||
if (!found_tensor_to_segmentation_calculator) {
|
||||
found_tensor_to_segmentation_calculator = true;
|
||||
} else {
|
||||
return absl::Status(CreateStatusWithPayload(
|
||||
absl::StatusCode::kFailedPrecondition,
|
||||
"The graph has more than one "
|
||||
"mediapipe.tasks.TensorsToSegmentationCalculator."));
|
||||
}
|
||||
TensorsToSegmentationCalculatorOptions options =
|
||||
node.options().GetExtension(
|
||||
TensorsToSegmentationCalculatorOptions::ext);
|
||||
if (!options.label_items().empty()) {
|
||||
for (int i = 0; i < options.label_items_size(); ++i) {
|
||||
if (!options.label_items().contains(i)) {
|
||||
return absl::Status(CreateStatusWithPayload(
|
||||
absl::StatusCode::kFailedPrecondition,
|
||||
absl::StrFormat("The lablemap have no expected key: %d.", i)));
|
||||
}
|
||||
labels.push_back(options.label_items().at(i).name());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return labels;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
|
||||
|
@ -140,13 +179,22 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
|
|||
kMicroSecondsPerMilliSecond);
|
||||
};
|
||||
}
|
||||
return core::VisionTaskApiFactory::Create<ImageSegmenter,
|
||||
ImageSegmenterGraphOptionsProto>(
|
||||
CreateGraphConfig(
|
||||
std::move(options_proto),
|
||||
options->running_mode == core::RunningMode::LIVE_STREAM),
|
||||
std::move(options->base_options.op_resolver), options->running_mode,
|
||||
std::move(packets_callback));
|
||||
|
||||
auto image_segmenter =
|
||||
core::VisionTaskApiFactory::Create<ImageSegmenter,
|
||||
ImageSegmenterGraphOptionsProto>(
|
||||
CreateGraphConfig(
|
||||
std::move(options_proto),
|
||||
options->running_mode == core::RunningMode::LIVE_STREAM),
|
||||
std::move(options->base_options.op_resolver), options->running_mode,
|
||||
std::move(packets_callback));
|
||||
if (!image_segmenter.ok()) {
|
||||
return image_segmenter.status();
|
||||
}
|
||||
ASSIGN_OR_RETURN(
|
||||
(*image_segmenter)->labels_,
|
||||
GetLabelsFromGraphConfig((*image_segmenter)->runner_->GetGraphConfig()));
|
||||
return image_segmenter;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::vector<Image>> ImageSegmenter::Segment(
|
||||
|
|
|
@ -189,6 +189,18 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
|
|||
|
||||
// Shuts down the ImageSegmenter when all works are done.
|
||||
absl::Status Close() { return runner_->Close(); }
|
||||
|
||||
// Gets the category label list that the ImageSegmenter can recognize. For
|
||||
// CATEGORY_MASK type, the index in the category mask corresponds to the
|
||||
// category in the label list. For CONFIDENCE_MASK type, the output mask list
|
||||
// at index corresponds to the category in the label list.
|
||||
//
|
||||
// If there is no label map provided in the model file, an empty label list is
|
||||
// returned.
|
||||
std::vector<std::string> GetLabels() { return labels_; }
|
||||
|
||||
private:
|
||||
std::vector<std::string> labels_;
|
||||
};
|
||||
|
||||
} // namespace image_segmenter
|
||||
|
|
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h"
|
||||
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
|
||||
|
@ -71,6 +72,13 @@ constexpr float kGoldenMaskSimilarity = 0.98;
|
|||
// 20 means class index 2, etc.
|
||||
constexpr int kGoldenMaskMagnificationFactor = 10;
|
||||
|
||||
constexpr std::array<absl::string_view, 21> kDeeplabLabelNames = {
|
||||
"background", "aeroplane", "bicycle", "bird", "boat",
|
||||
"bottle", "bus", "car", "cat", "chair",
|
||||
"cow", "dining table", "dog", "horse", "motorbike",
|
||||
"person", "potted plant", "sheep", "sofa", "train",
|
||||
"tv"};
|
||||
|
||||
// Intentionally converting output into CV_8UC1 and then again into CV_32FC1
|
||||
// as expected outputs are stored in CV_8UC1, so this conversion allows to do
|
||||
// fair comparison.
|
||||
|
@ -244,6 +252,22 @@ TEST_F(CreateFromOptionsTest, FailsWithInputChannelOneModel) {
|
|||
"channels = 3 or 4."));
|
||||
}
|
||||
|
||||
TEST(GetLabelsTest, SucceedsWithLabelsInModel) {
|
||||
auto options = std::make_unique<ImageSegmenterOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
|
||||
options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
||||
ImageSegmenter::Create(std::move(options)));
|
||||
const auto& labels = segmenter->GetLabels();
|
||||
ASSERT_FALSE(labels.empty());
|
||||
ASSERT_EQ(labels.size(), kDeeplabLabelNames.size());
|
||||
for (int i = 0; i < labels.size(); ++i) {
|
||||
EXPECT_EQ(labels[i], kDeeplabLabelNames[i]);
|
||||
}
|
||||
}
|
||||
|
||||
class ImageModeTest : public tflite_shims::testing::Test {};
|
||||
|
||||
TEST_F(ImageModeTest, SucceedsWithCategoryMask) {
|
||||
|
|
|
@ -108,31 +108,31 @@ std::vector<DetectionProto>
|
|||
GenerateMobileSsdNoImageResizingFullExpectedResults() {
|
||||
return {ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||
label: "cat"
|
||||
score: 0.6328125
|
||||
score: 0.6210937
|
||||
location_data {
|
||||
format: BOUNDING_BOX
|
||||
bounding_box { xmin: 14 ymin: 197 width: 98 height: 99 }
|
||||
bounding_box { xmin: 15 ymin: 197 width: 98 height: 99 }
|
||||
})pb"),
|
||||
ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||
label: "cat"
|
||||
score: 0.59765625
|
||||
score: 0.609375
|
||||
location_data {
|
||||
format: BOUNDING_BOX
|
||||
bounding_box { xmin: 151 ymin: 78 width: 104 height: 223 }
|
||||
bounding_box { xmin: 150 ymin: 78 width: 104 height: 223 }
|
||||
})pb"),
|
||||
ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||
label: "cat"
|
||||
score: 0.5
|
||||
location_data {
|
||||
format: BOUNDING_BOX
|
||||
bounding_box { xmin: 65 ymin: 199 width: 41 height: 101 }
|
||||
bounding_box { xmin: 64 ymin: 199 width: 42 height: 101 }
|
||||
})pb"),
|
||||
ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||
label: "dog"
|
||||
score: 0.48828125
|
||||
score: 0.5
|
||||
location_data {
|
||||
format: BOUNDING_BOX
|
||||
bounding_box { xmin: 12 ymin: 110 width: 153 height: 193 }
|
||||
bounding_box { xmin: 14 ymin: 110 width: 153 height: 193 }
|
||||
})pb")};
|
||||
}
|
||||
|
||||
|
@ -268,7 +268,7 @@ TEST_F(CreateFromOptionsTest, FailsWithIllegalCallbackInImageOrVideoMode) {
|
|||
options->running_mode = running_mode;
|
||||
options->result_callback =
|
||||
[](absl::StatusOr<ObjectDetectorResult> detections, const Image& image,
|
||||
int64 timestamp_ms) {};
|
||||
int64_t timestamp_ms) {};
|
||||
absl::StatusOr<std::unique_ptr<ObjectDetector>> object_detector =
|
||||
ObjectDetector::Create(std::move(options));
|
||||
EXPECT_EQ(object_detector.status().code(),
|
||||
|
@ -381,28 +381,28 @@ TEST_F(ImageModeTest, Succeeds) {
|
|||
score: 0.69921875
|
||||
location_data {
|
||||
format: BOUNDING_BOX
|
||||
bounding_box { xmin: 608 ymin: 161 width: 381 height: 439 }
|
||||
bounding_box { xmin: 608 ymin: 164 width: 381 height: 432 }
|
||||
})pb"),
|
||||
ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||
label: "cat"
|
||||
score: 0.64453125
|
||||
score: 0.65625
|
||||
location_data {
|
||||
format: BOUNDING_BOX
|
||||
bounding_box { xmin: 60 ymin: 398 width: 386 height: 196 }
|
||||
bounding_box { xmin: 57 ymin: 398 width: 386 height: 196 }
|
||||
})pb"),
|
||||
ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||
label: "cat"
|
||||
score: 0.51171875
|
||||
location_data {
|
||||
format: BOUNDING_BOX
|
||||
bounding_box { xmin: 256 ymin: 395 width: 173 height: 202 }
|
||||
bounding_box { xmin: 256 ymin: 394 width: 173 height: 202 }
|
||||
})pb"),
|
||||
ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||
label: "cat"
|
||||
score: 0.48828125
|
||||
location_data {
|
||||
format: BOUNDING_BOX
|
||||
bounding_box { xmin: 362 ymin: 191 width: 325 height: 419 }
|
||||
bounding_box { xmin: 360 ymin: 195 width: 330 height: 412 }
|
||||
})pb")}));
|
||||
}
|
||||
|
||||
|
@ -484,10 +484,10 @@ TEST_F(ImageModeTest, SucceedsWithScoreCalibration) {
|
|||
results,
|
||||
ConvertToDetectionResult({ParseTextProtoOrDie<DetectionProto>(R"pb(
|
||||
label: "cat"
|
||||
score: 0.6531269142
|
||||
score: 0.650467276
|
||||
location_data {
|
||||
format: BOUNDING_BOX
|
||||
bounding_box { xmin: 14 ymin: 197 width: 98 height: 99 }
|
||||
bounding_box { xmin: 15 ymin: 197 width: 98 height: 99 }
|
||||
})pb")}));
|
||||
}
|
||||
|
||||
|
@ -507,9 +507,9 @@ TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) {
|
|||
GenerateMobileSsdNoImageResizingFullExpectedResults();
|
||||
|
||||
ExpectApproximatelyEqual(
|
||||
results, ConvertToDetectionResult({full_expected_results[0],
|
||||
full_expected_results[1],
|
||||
full_expected_results[2]}));
|
||||
results, ConvertToDetectionResult(
|
||||
{full_expected_results[0], full_expected_results[1],
|
||||
full_expected_results[2], full_expected_results[3]}));
|
||||
}
|
||||
|
||||
TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) {
|
||||
|
@ -685,7 +685,7 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
|
|||
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
|
||||
options->running_mode = core::RunningMode::LIVE_STREAM;
|
||||
options->result_callback = [](absl::StatusOr<ObjectDetectorResult> detections,
|
||||
const Image& image, int64 timestamp_ms) {};
|
||||
const Image& image, int64_t timestamp_ms) {};
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
|
||||
ObjectDetector::Create(std::move(options)));
|
||||
|
@ -716,7 +716,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
|
|||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
|
||||
options->result_callback = [](absl::StatusOr<ObjectDetectorResult> detections,
|
||||
const Image& image, int64 timestamp_ms) {};
|
||||
const Image& image, int64_t timestamp_ms) {};
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
|
||||
ObjectDetector::Create(std::move(options)));
|
||||
MP_ASSERT_OK(object_detector->DetectAsync(image, 1));
|
||||
|
@ -742,13 +742,13 @@ TEST_F(LiveStreamModeTest, Succeeds) {
|
|||
options->running_mode = core::RunningMode::LIVE_STREAM;
|
||||
std::vector<ObjectDetectorResult> detection_results;
|
||||
std::vector<std::pair<int, int>> image_sizes;
|
||||
std::vector<int64> timestamps;
|
||||
std::vector<int64_t> timestamps;
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
|
||||
options->result_callback =
|
||||
[&detection_results, &image_sizes, ×tamps](
|
||||
absl::StatusOr<ObjectDetectorResult> detections, const Image& image,
|
||||
int64 timestamp_ms) {
|
||||
int64_t timestamp_ms) {
|
||||
MP_ASSERT_OK(detections.status());
|
||||
detection_results.push_back(std::move(detections).value());
|
||||
image_sizes.push_back({image.width(), image.height()});
|
||||
|
@ -775,7 +775,7 @@ TEST_F(LiveStreamModeTest, Succeeds) {
|
|||
EXPECT_EQ(image_size.first, image.width());
|
||||
EXPECT_EQ(image_size.second, image.height());
|
||||
}
|
||||
int64 timestamp_ms = -1;
|
||||
int64_t timestamp_ms = -1;
|
||||
for (const auto& timestamp : timestamps) {
|
||||
EXPECT_GT(timestamp, timestamp_ms);
|
||||
timestamp_ms = timestamp;
|
||||
|
|
53
mediapipe/tasks/cc/vision/pose_detector/BUILD
Normal file
|
@ -0,0 +1,53 @@
|
|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = [
|
||||
"//mediapipe/tasks:internal",
|
||||
])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
cc_library(
|
||||
name = "pose_detector_graph",
|
||||
srcs = ["pose_detector_graph.cc"],
|
||||
deps = [
|
||||
"//mediapipe/calculators/core:clip_vector_size_calculator_cc_proto",
|
||||
"//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
|
||||
"//mediapipe/calculators/tensor:inference_calculator",
|
||||
"//mediapipe/calculators/tensor:tensors_to_detections_calculator",
|
||||
"//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto",
|
||||
"//mediapipe/calculators/tflite:ssd_anchors_calculator",
|
||||
"//mediapipe/calculators/tflite:ssd_anchors_calculator_cc_proto",
|
||||
"//mediapipe/calculators/util:detection_projection_calculator",
|
||||
"//mediapipe/calculators/util:detection_transformation_calculator",
|
||||
"//mediapipe/calculators/util:detections_to_rects_calculator",
|
||||
"//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto",
|
||||
"//mediapipe/calculators/util:non_max_suppression_calculator",
|
||||
"//mediapipe/calculators/util:non_max_suppression_calculator_cc_proto",
|
||||
"//mediapipe/calculators/util:rect_transformation_calculator",
|
||||
"//mediapipe/calculators/util:rect_transformation_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/framework:subgraph",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
"//mediapipe/framework/formats:image",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
|
||||
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||
"//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_cc_proto",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
354
mediapipe/tasks/cc/vision/pose_detector/pose_detector_graph.cc
Normal file
|
@ -0,0 +1,354 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/calculators/core/clip_vector_size_calculator.pb.h"
|
||||
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
|
||||
#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h"
|
||||
#include "mediapipe/calculators/tflite/ssd_anchors_calculator.pb.h"
|
||||
#include "mediapipe/calculators/util/detections_to_rects_calculator.pb.h"
|
||||
#include "mediapipe/calculators/util/non_max_suppression_calculator.pb.h"
|
||||
#include "mediapipe/calculators/util/rect_transformation_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/subgraph.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
|
||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||
#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace pose_detector {
|
||||
|
||||
using ::mediapipe::NormalizedRect;
|
||||
using ::mediapipe::Tensor;
|
||||
using ::mediapipe::api2::Input;
|
||||
using ::mediapipe::api2::Output;
|
||||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
using ::mediapipe::tasks::vision::pose_detector::proto::
|
||||
PoseDetectorGraphOptions;
|
||||
|
||||
namespace {
|
||||
constexpr char kImageTag[] = "IMAGE";
|
||||
constexpr char kNormRectTag[] = "NORM_RECT";
|
||||
constexpr char kTensorsTag[] = "TENSORS";
|
||||
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
|
||||
constexpr char kAnchorsTag[] = "ANCHORS";
|
||||
constexpr char kDetectionsTag[] = "DETECTIONS";
|
||||
constexpr char kNormRectsTag[] = "NORM_RECTS";
|
||||
constexpr char kPixelDetectionsTag[] = "PIXEL_DETECTIONS";
|
||||
constexpr char kPoseRectsTag[] = "POSE_RECTS";
|
||||
constexpr char kExpandedPoseRectsTag[] = "EXPANDED_POSE_RECTS";
|
||||
constexpr char kMatrixTag[] = "MATRIX";
|
||||
constexpr char kProjectionMatrixTag[] = "PROJECTION_MATRIX";
|
||||
|
||||
struct PoseDetectionOuts {
|
||||
Source<std::vector<Detection>> pose_detections;
|
||||
Source<std::vector<NormalizedRect>> pose_rects;
|
||||
Source<std::vector<NormalizedRect>> expanded_pose_rects;
|
||||
Source<Image> image;
|
||||
};
|
||||
|
||||
// TODO: Configure detection related calculators in pose
|
||||
// detector with model metadata.
|
||||
void ConfigureSsdAnchorsCalculator(
|
||||
mediapipe::SsdAnchorsCalculatorOptions* options) {
|
||||
// Derived from
|
||||
// mediapipe/modules/pose_detection/pose_detection_gpu.pbtxt
|
||||
options->set_num_layers(5);
|
||||
options->set_min_scale(0.1484375);
|
||||
options->set_max_scale(0.75);
|
||||
options->set_input_size_height(224);
|
||||
options->set_input_size_width(224);
|
||||
options->set_anchor_offset_x(0.5);
|
||||
options->set_anchor_offset_y(0.5);
|
||||
options->add_strides(8);
|
||||
options->add_strides(16);
|
||||
options->add_strides(32);
|
||||
options->add_strides(32);
|
||||
options->add_strides(32);
|
||||
options->add_aspect_ratios(1.0);
|
||||
options->set_fixed_anchor_size(true);
|
||||
}
|
||||
|
||||
// TODO: Configure detection related calculators in pose
|
||||
// detector with model metadata.
|
||||
void ConfigureTensorsToDetectionsCalculator(
|
||||
const PoseDetectorGraphOptions& tasks_options,
|
||||
mediapipe::TensorsToDetectionsCalculatorOptions* options) {
|
||||
// Derived from
|
||||
// mediapipe/modules/pose_detection/pose_detection_gpu.pbtxt
|
||||
options->set_num_classes(1);
|
||||
options->set_num_boxes(2254);
|
||||
options->set_num_coords(12);
|
||||
options->set_box_coord_offset(0);
|
||||
options->set_keypoint_coord_offset(4);
|
||||
options->set_num_keypoints(4);
|
||||
options->set_num_values_per_keypoint(2);
|
||||
options->set_sigmoid_score(true);
|
||||
options->set_score_clipping_thresh(100.0);
|
||||
options->set_reverse_output_order(true);
|
||||
options->set_min_score_thresh(tasks_options.min_detection_confidence());
|
||||
options->set_x_scale(224.0);
|
||||
options->set_y_scale(224.0);
|
||||
options->set_w_scale(224.0);
|
||||
options->set_h_scale(224.0);
|
||||
}
|
||||
|
||||
void ConfigureNonMaxSuppressionCalculator(
|
||||
const PoseDetectorGraphOptions& tasks_options,
|
||||
mediapipe::NonMaxSuppressionCalculatorOptions* options) {
|
||||
options->set_min_suppression_threshold(
|
||||
tasks_options.min_suppression_threshold());
|
||||
options->set_overlap_type(
|
||||
mediapipe::NonMaxSuppressionCalculatorOptions::INTERSECTION_OVER_UNION);
|
||||
options->set_algorithm(
|
||||
mediapipe::NonMaxSuppressionCalculatorOptions::WEIGHTED);
|
||||
}
|
||||
|
||||
// TODO: Configure detection related calculators in pose
|
||||
// detector with model metadata.
|
||||
void ConfigureDetectionsToRectsCalculator(
|
||||
mediapipe::DetectionsToRectsCalculatorOptions* options) {
|
||||
options->set_rotation_vector_start_keypoint_index(0);
|
||||
options->set_rotation_vector_end_keypoint_index(2);
|
||||
options->set_rotation_vector_target_angle(90);
|
||||
options->set_output_zero_rect_for_empty_detections(true);
|
||||
}
|
||||
|
||||
// TODO: Configure detection related calculators in pose
|
||||
// detector with model metadata.
|
||||
void ConfigureRectTransformationCalculator(
|
||||
mediapipe::RectTransformationCalculatorOptions* options) {
|
||||
options->set_scale_x(2.6);
|
||||
options->set_scale_y(2.6);
|
||||
options->set_shift_y(-0.5);
|
||||
options->set_square_long(true);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// A "mediapipe.tasks.vision.pose_detector.PoseDetectorGraph" performs pose
|
||||
// detection.
|
||||
//
|
||||
// Inputs:
|
||||
// IMAGE - Image
|
||||
// Image to perform detection on.
|
||||
// NORM_RECT - NormalizedRect @Optional
|
||||
// Describes image rotation and region of image to perform detection on. If
|
||||
// not provided, whole image is used for pose detection.
|
||||
//
|
||||
// Outputs:
|
||||
// DETECTIONS - std::vector<Detection>
|
||||
// Detected pose with maximum `num_poses` specified in options.
|
||||
// POSE_RECTS - std::vector<NormalizedRect>
|
||||
// Detected pose bounding boxes in normalized coordinates.
|
||||
// EXPANDED_POSE_RECTS - std::vector<NormalizedRect>
|
||||
// Expanded pose bounding boxes in normalized coordinates so that bounding
|
||||
// boxes likely contain the whole pose. This is usually used as RoI for pose
|
||||
// landmarks detection to run on.
|
||||
// IMAGE - Image
|
||||
// The input image that the pose detector runs on and has the pixel data
|
||||
// stored on the target storage (CPU vs GPU).
|
||||
// All returned coordinates are in the unrotated and uncropped input image
|
||||
// coordinates system.
|
||||
//
|
||||
// Example:
|
||||
// node {
|
||||
// calculator: "mediapipe.tasks.vision.pose_detector.PoseDetectorGraph"
|
||||
// input_stream: "IMAGE:image"
|
||||
// input_stream: "NORM_RECT:norm_rect"
|
||||
// output_stream: "DETECTIONS:palm_detections"
|
||||
// output_stream: "POSE_RECTS:pose_rects"
|
||||
// output_stream: "EXPANDED_POSE_RECTS:expanded_pose_rects"
|
||||
// output_stream: "IMAGE:image_out"
|
||||
// options {
|
||||
// [mediapipe.tasks.vision.pose_detector.proto.PoseDetectorGraphOptions.ext]
|
||||
// {
|
||||
// base_options {
|
||||
// model_asset {
|
||||
// file_name: "pose_detection.tflite"
|
||||
// }
|
||||
// }
|
||||
// min_detection_confidence: 0.5
|
||||
// num_poses: 2
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
class PoseDetectorGraph : public core::ModelTaskGraph {
|
||||
public:
|
||||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||
SubgraphContext* sc) override {
|
||||
ASSIGN_OR_RETURN(const auto* model_resources,
|
||||
CreateModelResources<PoseDetectorGraphOptions>(sc));
|
||||
Graph graph;
|
||||
ASSIGN_OR_RETURN(auto outs,
|
||||
BuildPoseDetectionSubgraph(
|
||||
sc->Options<PoseDetectorGraphOptions>(),
|
||||
*model_resources, graph[Input<Image>(kImageTag)],
|
||||
graph[Input<NormalizedRect>(kNormRectTag)], graph));
|
||||
|
||||
outs.pose_detections >>
|
||||
graph.Out(kDetectionsTag).Cast<std::vector<Detection>>();
|
||||
outs.pose_rects >>
|
||||
graph.Out(kPoseRectsTag).Cast<std::vector<NormalizedRect>>();
|
||||
outs.expanded_pose_rects >>
|
||||
graph.Out(kExpandedPoseRectsTag).Cast<std::vector<NormalizedRect>>();
|
||||
outs.image >> graph.Out(kImageTag).Cast<Image>();
|
||||
|
||||
return graph.GetConfig();
|
||||
}
|
||||
|
||||
private:
|
||||
absl::StatusOr<PoseDetectionOuts> BuildPoseDetectionSubgraph(
|
||||
const PoseDetectorGraphOptions& subgraph_options,
|
||||
const core::ModelResources& model_resources, Source<Image> image_in,
|
||||
Source<NormalizedRect> norm_rect_in, Graph& graph) {
|
||||
// Image preprocessing subgraph to convert image to tensor for the tflite
|
||||
// model.
|
||||
auto& preprocessing = graph.AddNode(
|
||||
"mediapipe.tasks.components.processors.ImagePreprocessingGraph");
|
||||
bool use_gpu =
|
||||
components::processors::DetermineImagePreprocessingGpuBackend(
|
||||
subgraph_options.base_options().acceleration());
|
||||
MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
|
||||
model_resources, use_gpu,
|
||||
&preprocessing.GetOptions<
|
||||
components::processors::proto::ImagePreprocessingGraphOptions>()));
|
||||
auto& image_to_tensor_options =
|
||||
*preprocessing
|
||||
.GetOptions<components::processors::proto::
|
||||
ImagePreprocessingGraphOptions>()
|
||||
.mutable_image_to_tensor_options();
|
||||
image_to_tensor_options.set_keep_aspect_ratio(true);
|
||||
image_to_tensor_options.set_border_mode(
|
||||
mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO);
|
||||
image_in >> preprocessing.In(kImageTag);
|
||||
norm_rect_in >> preprocessing.In(kNormRectTag);
|
||||
auto preprocessed_tensors = preprocessing.Out(kTensorsTag);
|
||||
auto matrix = preprocessing.Out(kMatrixTag);
|
||||
auto image_size = preprocessing.Out(kImageSizeTag);
|
||||
|
||||
// Pose detection model inference.
|
||||
auto& inference = AddInference(
|
||||
model_resources, subgraph_options.base_options().acceleration(), graph);
|
||||
preprocessed_tensors >> inference.In(kTensorsTag);
|
||||
auto model_output_tensors =
|
||||
inference.Out(kTensorsTag).Cast<std::vector<Tensor>>();
|
||||
|
||||
// Generates a single side packet containing a vector of SSD anchors.
|
||||
auto& ssd_anchor = graph.AddNode("SsdAnchorsCalculator");
|
||||
ConfigureSsdAnchorsCalculator(
|
||||
&ssd_anchor.GetOptions<mediapipe::SsdAnchorsCalculatorOptions>());
|
||||
auto anchors = ssd_anchor.SideOut("");
|
||||
|
||||
// Converts output tensors to Detections.
|
||||
auto& tensors_to_detections =
|
||||
graph.AddNode("TensorsToDetectionsCalculator");
|
||||
ConfigureTensorsToDetectionsCalculator(
|
||||
subgraph_options,
|
||||
&tensors_to_detections
|
||||
.GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>());
|
||||
model_output_tensors >> tensors_to_detections.In(kTensorsTag);
|
||||
anchors >> tensors_to_detections.SideIn(kAnchorsTag);
|
||||
auto detections = tensors_to_detections.Out(kDetectionsTag);
|
||||
|
||||
// Non maximum suppression removes redundant pose detections.
|
||||
auto& non_maximum_suppression =
|
||||
graph.AddNode("NonMaxSuppressionCalculator");
|
||||
ConfigureNonMaxSuppressionCalculator(
|
||||
subgraph_options,
|
||||
&non_maximum_suppression
|
||||
.GetOptions<mediapipe::NonMaxSuppressionCalculatorOptions>());
|
||||
detections >> non_maximum_suppression.In("");
|
||||
auto nms_detections = non_maximum_suppression.Out("");
|
||||
|
||||
// Projects detections back into the input image coordinates system.
|
||||
auto& detection_projection = graph.AddNode("DetectionProjectionCalculator");
|
||||
nms_detections >> detection_projection.In(kDetectionsTag);
|
||||
matrix >> detection_projection.In(kProjectionMatrixTag);
|
||||
Source<std::vector<Detection>> pose_detections =
|
||||
detection_projection.Out(kDetectionsTag).Cast<std::vector<Detection>>();
|
||||
|
||||
if (subgraph_options.has_num_poses()) {
|
||||
// Clip pose detections to maximum number of poses.
|
||||
auto& clip_detection_vector_size =
|
||||
graph.AddNode("ClipDetectionVectorSizeCalculator");
|
||||
clip_detection_vector_size
|
||||
.GetOptions<mediapipe::ClipVectorSizeCalculatorOptions>()
|
||||
.set_max_vec_size(subgraph_options.num_poses());
|
||||
pose_detections >> clip_detection_vector_size.In("");
|
||||
pose_detections =
|
||||
clip_detection_vector_size.Out("").Cast<std::vector<Detection>>();
|
||||
}
|
||||
|
||||
// Converts results of pose detection into a rectangle (normalized by image
|
||||
// size) that encloses the pose and is rotated such that the line connecting
// keypoints 0 and 2 is aligned with the X-axis of the rectangle.
|
||||
auto& detections_to_rects = graph.AddNode("DetectionsToRectsCalculator");
|
||||
ConfigureDetectionsToRectsCalculator(
|
||||
&detections_to_rects
|
||||
.GetOptions<mediapipe::DetectionsToRectsCalculatorOptions>());
|
||||
image_size >> detections_to_rects.In(kImageSizeTag);
|
||||
pose_detections >> detections_to_rects.In(kDetectionsTag);
|
||||
auto pose_rects = detections_to_rects.Out(kNormRectsTag)
|
||||
.Cast<std::vector<NormalizedRect>>();
|
||||
|
||||
// Expands and shifts the rectangle that contains the pose so that it's
|
||||
// likely to cover the entire pose.
|
||||
auto& rect_transformation = graph.AddNode("RectTransformationCalculator");
|
||||
ConfigureRectTransformationCalculator(
|
||||
&rect_transformation
|
||||
.GetOptions<mediapipe::RectTransformationCalculatorOptions>());
|
||||
pose_rects >> rect_transformation.In(kNormRectsTag);
|
||||
image_size >> rect_transformation.In(kImageSizeTag);
|
||||
auto expanded_pose_rects =
|
||||
rect_transformation.Out("").Cast<std::vector<NormalizedRect>>();
|
||||
|
||||
// Calculator to convert relative detection bounding boxes to pixel
|
||||
// detection bounding boxes.
|
||||
auto& detection_transformation =
|
||||
graph.AddNode("DetectionTransformationCalculator");
|
||||
detection_projection.Out(kDetectionsTag) >>
|
||||
detection_transformation.In(kDetectionsTag);
|
||||
preprocessing.Out(kImageSizeTag) >>
|
||||
detection_transformation.In(kImageSizeTag);
|
||||
auto pose_pixel_detections =
|
||||
detection_transformation.Out(kPixelDetectionsTag)
|
||||
.Cast<std::vector<Detection>>();
|
||||
|
||||
return PoseDetectionOuts{
|
||||
/* pose_detections= */ pose_pixel_detections,
|
||||
/* pose_rects= */ pose_rects,
|
||||
/* expanded_pose_rects= */ expanded_pose_rects,
|
||||
/* image= */ preprocessing.Out(kImageTag).Cast<Image>()};
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_MEDIAPIPE_GRAPH(
|
||||
::mediapipe::tasks::vision::pose_detector::PoseDetectorGraph);
|
||||
|
||||
} // namespace pose_detector
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
|
@ -0,0 +1,165 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/framework/packet.h"
|
||||
#include "mediapipe/framework/port/file_helpers.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||
#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace pose_detector {
|
||||
namespace {
|
||||
|
||||
using ::file::Defaults;
|
||||
using ::file::GetTextProto;
|
||||
using ::mediapipe::NormalizedRect;
|
||||
using ::mediapipe::api2::Input;
|
||||
using ::mediapipe::api2::Output;
|
||||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
using ::mediapipe::file::JoinPath;
|
||||
using ::mediapipe::tasks::core::TaskRunner;
|
||||
using ::mediapipe::tasks::vision::pose_detector::proto::
|
||||
PoseDetectorGraphOptions;
|
||||
using ::testing::EqualsProto;
|
||||
using ::testing::Pointwise;
|
||||
using ::testing::TestParamInfo;
|
||||
using ::testing::TestWithParam;
|
||||
using ::testing::Values;
|
||||
using ::testing::proto::Approximately;
|
||||
using ::testing::proto::Partially;
|
||||
|
||||
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
|
||||
constexpr char kPoseDetectionModel[] = "pose_detection.tflite";
|
||||
constexpr char kPortraitImage[] = "pose.jpg";
|
||||
constexpr char kPoseExpectedDetection[] = "pose_expected_detection.pbtxt";
|
||||
|
||||
constexpr char kImageTag[] = "IMAGE";
|
||||
constexpr char kImageName[] = "image";
|
||||
constexpr char kNormRectTag[] = "NORM_RECT";
|
||||
constexpr char kNormRectName[] = "norm_rect";
|
||||
constexpr char kDetectionsTag[] = "DETECTIONS";
|
||||
constexpr char kDetectionsName[] = "detections";
|
||||
|
||||
constexpr float kPoseDetectionMaxDiff = 0.01;
|
||||
|
||||
// Helper function to create a TaskRunner.
|
||||
absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner(
|
||||
absl::string_view model_name) {
|
||||
Graph graph;
|
||||
|
||||
auto& pose_detector_graph =
|
||||
graph.AddNode("mediapipe.tasks.vision.pose_detector.PoseDetectorGraph");
|
||||
|
||||
auto options = std::make_unique<PoseDetectorGraphOptions>();
|
||||
options->mutable_base_options()->mutable_model_asset()->set_file_name(
|
||||
JoinPath("./", kTestDataDirectory, model_name));
|
||||
options->set_min_detection_confidence(0.6);
|
||||
options->set_min_suppression_threshold(0.3);
|
||||
pose_detector_graph.GetOptions<PoseDetectorGraphOptions>().Swap(
|
||||
options.get());
|
||||
|
||||
graph[Input<Image>(kImageTag)].SetName(kImageName) >>
|
||||
pose_detector_graph.In(kImageTag);
|
||||
graph[Input<NormalizedRect>(kNormRectTag)].SetName(kNormRectName) >>
|
||||
pose_detector_graph.In(kNormRectTag);
|
||||
|
||||
pose_detector_graph.Out(kDetectionsTag).SetName(kDetectionsName) >>
|
||||
graph[Output<std::vector<Detection>>(kDetectionsTag)];
|
||||
|
||||
return TaskRunner::Create(
|
||||
graph.GetConfig(), std::make_unique<core::MediaPipeBuiltinOpResolver>());
|
||||
}
|
||||
|
||||
Detection GetExpectedPoseDetectionResult(absl::string_view file_name) {
|
||||
Detection detection;
|
||||
CHECK_OK(GetTextProto(file::JoinPath("./", kTestDataDirectory, file_name),
|
||||
&detection, Defaults()))
|
||||
<< "Expected pose detection result does not exist.";
|
||||
return detection;
|
||||
}
|
||||
|
||||
struct TestParams {
|
||||
// The name of this test, for convenience when displaying test results.
|
||||
std::string test_name;
|
||||
// The filename of the pose detection model.
|
||||
std::string pose_detection_model_name;
|
||||
// The filename of test image.
|
||||
std::string test_image_name;
|
||||
// Expected pose detection results.
|
||||
std::vector<Detection> expected_result;
|
||||
};
|
||||
|
||||
class PoseDetectorGraphTest : public testing::TestWithParam<TestParams> {};
|
||||
|
||||
TEST_P(PoseDetectorGraphTest, Succeed) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
|
||||
GetParam().test_image_name)));
|
||||
NormalizedRect input_norm_rect;
|
||||
input_norm_rect.set_x_center(0.5);
|
||||
input_norm_rect.set_y_center(0.5);
|
||||
input_norm_rect.set_width(1.0);
|
||||
input_norm_rect.set_height(1.0);
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto task_runner, CreateTaskRunner(GetParam().pose_detection_model_name));
|
||||
auto output_packets = task_runner->Process(
|
||||
{{kImageName, MakePacket<Image>(std::move(image))},
|
||||
{kNormRectName,
|
||||
MakePacket<NormalizedRect>(std::move(input_norm_rect))}});
|
||||
MP_ASSERT_OK(output_packets);
|
||||
const std::vector<Detection>& pose_detections =
|
||||
(*output_packets)[kDetectionsName].Get<std::vector<Detection>>();
|
||||
EXPECT_THAT(pose_detections, Pointwise(Approximately(Partially(EqualsProto()),
|
||||
kPoseDetectionMaxDiff),
|
||||
GetParam().expected_result));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
PoseDetectorGraphTest, PoseDetectorGraphTest,
|
||||
Values(TestParams{.test_name = "DetectPose",
|
||||
.pose_detection_model_name = kPoseDetectionModel,
|
||||
.test_image_name = kPortraitImage,
|
||||
.expected_result = {GetExpectedPoseDetectionResult(
|
||||
kPoseExpectedDetection)}}),
|
||||
[](const TestParamInfo<PoseDetectorGraphTest::ParamType>& info) {
|
||||
return info.param.test_name;
|
||||
});
|
||||
|
||||
} // namespace
|
||||
} // namespace pose_detector
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
31
mediapipe/tasks/cc/vision/pose_detector/proto/BUILD
Normal file
|
@ -0,0 +1,31 @@
|
|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
|
||||
|
||||
package(default_visibility = [
|
||||
"//mediapipe/tasks:internal",
|
||||
])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "pose_detector_graph_options_proto",
|
||||
srcs = ["pose_detector_graph_options.proto"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
||||
],
|
||||
)
|
|
@ -0,0 +1,45 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe.tasks.vision.pose_detector.proto;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
import "mediapipe/framework/calculator_options.proto";
|
||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||
|
||||
option java_package = "com.google.mediapipe.tasks.vision.posedetector.proto";
|
||||
option java_outer_classname = "PoseDetectorGraphOptionsProto";
|
||||
|
||||
message PoseDetectorGraphOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional PoseDetectorGraphOptions ext = 514774813;
|
||||
}
|
||||
// Base options for configuring Task library, such as specifying the TfLite
|
||||
// model file with metadata, accelerator options, etc.
|
||||
optional core.proto.BaseOptions base_options = 1;
|
||||
|
||||
// Minimum confidence value ([0.0, 1.0]) for confidence score to be considered
|
||||
// successfully detecting a pose in the image.
|
||||
optional float min_detection_confidence = 2 [default = 0.5];
|
||||
|
||||
// IoU threshold ([0.0, 1.0]) for non-maximum suppression to be considered
|
||||
// duplicate detections.
|
||||
optional float min_suppression_threshold = 3 [default = 0.5];
|
||||
|
||||
// Maximum number of poses to detect in the image.
|
||||
optional int32 num_poses = 4;
|
||||
}
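A minimal sketch, not part of this commit, of how these fields are populated from C++ when the subgraph node is added; it reuses the calls and threshold values from pose_detector_graph_test.cc, and the model asset path is illustrative.

#include <memory>

#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"

namespace {
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::tasks::vision::pose_detector::proto::PoseDetectorGraphOptions;

void ConfigurePoseDetectorNode(Graph& graph) {
  auto& pose_detector_graph =
      graph.AddNode("mediapipe.tasks.vision.pose_detector.PoseDetectorGraph");
  auto options = std::make_unique<PoseDetectorGraphOptions>();
  // Illustrative asset path; the test resolves pose_detection.tflite via JoinPath.
  options->mutable_base_options()->mutable_model_asset()->set_file_name(
      "pose_detection.tflite");
  options->set_min_detection_confidence(0.6);   // default is 0.5
  options->set_min_suppression_threshold(0.3);  // default is 0.5
  options->set_num_poses(2);  // unset means the detection vector is not clipped
  pose_detector_graph.GetOptions<PoseDetectorGraphOptions>().Swap(options.get());
}
}  // namespace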
|
|
@ -16,6 +16,7 @@ package com.google.mediapipe.tasks.core;
|
|||
|
||||
import android.content.Context;
|
||||
import android.util.Log;
|
||||
import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig;
|
||||
import com.google.mediapipe.framework.AndroidAssetUtil;
|
||||
import com.google.mediapipe.framework.AndroidPacketCreator;
|
||||
import com.google.mediapipe.framework.Graph;
|
||||
|
@ -201,6 +202,10 @@ public class TaskRunner implements AutoCloseable {
|
|||
}
|
||||
}
|
||||
|
||||
public CalculatorGraphConfig getCalculatorGraphConfig() {
|
||||
return graph.getCalculatorGraphConfig();
|
||||
}
|
||||
|
||||
private synchronized void addPackets(Map<String, Packet> inputs, long inputTimestamp) {
|
||||
if (!graphStarted.get()) {
|
||||
reportError(
|
||||
|
|
|
@ -41,6 +41,7 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [
|
|||
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",
|
||||
|
|
|
@ -197,6 +197,7 @@ android_library(
|
|||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||
"//mediapipe/java/com/google/mediapipe/framework/image",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||
|
|
|
@ -17,6 +17,7 @@ package com.google.mediapipe.tasks.vision.imagesegmenter;
|
|||
import android.content.Context;
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
|
||||
import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig;
|
||||
import com.google.mediapipe.framework.AndroidPacketGetter;
|
||||
import com.google.mediapipe.framework.MediaPipeException;
|
||||
import com.google.mediapipe.framework.Packet;
|
||||
|
@ -24,6 +25,7 @@ import com.google.mediapipe.framework.PacketGetter;
|
|||
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
||||
import com.google.mediapipe.framework.image.ByteBufferImageBuilder;
|
||||
import com.google.mediapipe.framework.image.MPImage;
|
||||
import com.google.mediapipe.tasks.TensorsToSegmentationCalculatorOptionsProto;
|
||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||
import com.google.mediapipe.tasks.core.ErrorListener;
|
||||
import com.google.mediapipe.tasks.core.OutputHandler;
|
||||
|
@ -88,8 +90,10 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
private static final int SEGMENTATION_OUT_STREAM_INDEX = 2;
|
||||
private static final String TASK_GRAPH_NAME =
|
||||
"mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
|
||||
|
||||
private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
|
||||
"mediapipe.tasks.TensorsToSegmentationCalculator";
|
||||
private boolean hasResultListener = false;
|
||||
private List<String> labels = new ArrayList<>();
|
||||
|
||||
/**
|
||||
* Creates an {@link ImageSegmenter} instance from an {@link ImageSegmenterOptions}.
|
||||
|
@ -190,6 +194,41 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
TaskRunner taskRunner, RunningMode runningMode, boolean hasResultListener) {
|
||||
super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
|
||||
this.hasResultListener = hasResultListener;
|
||||
populateLabels();
|
||||
}
|
||||
/**
|
||||
* Populates the labels field from the label map in TensorsToSegmentationCalculator.
|
||||
*
|
||||
* @throws MediaPipeException if there is an error while looking up TensorsToSegmentationCalculator.
|
||||
*/
|
||||
private void populateLabels() {
|
||||
CalculatorGraphConfig graphConfig = this.runner.getCalculatorGraphConfig();
|
||||
|
||||
boolean foundTensorsToSegmentation = false;
|
||||
for (CalculatorGraphConfig.Node node : graphConfig.getNodeList()) {
|
||||
if (node.getName().contains(TENSORS_TO_SEGMENTATION_CALCULATOR_NAME)) {
|
||||
if (foundTensorsToSegmentation) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.INTERNAL.ordinal(),
|
||||
"The graph has more than one mediapipe.tasks.TensorsToSegmentationCalculator.");
|
||||
}
|
||||
foundTensorsToSegmentation = true;
|
||||
TensorsToSegmentationCalculatorOptionsProto.TensorsToSegmentationCalculatorOptions options =
|
||||
node.getOptions()
|
||||
.getExtension(
|
||||
TensorsToSegmentationCalculatorOptionsProto
|
||||
.TensorsToSegmentationCalculatorOptions.ext);
|
||||
for (int i = 0; i < options.getLabelItemsMap().size(); i++) {
|
||||
Long labelKey = Long.valueOf(i);
|
||||
if (!options.getLabelItemsMap().containsKey(labelKey)) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.INTERNAL.ordinal(),
|
||||
"The labelmap does not have the expected key: " + labelKey);
|
||||
}
|
||||
labels.add(options.getLabelItemsMap().get(labelKey).getName());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -473,6 +512,17 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
sendLiveStreamData(image, imageProcessingOptions, timestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the category label list that the ImageSegmenter can recognize. For CATEGORY_MASK type, the
|
||||
* index in the category mask corresponds to the category in the label list. For CONFIDENCE_MASK
|
||||
* type, the index in the output mask list corresponds to the category in the label list.
|
||||
*
|
||||
* <p>If no labelmap is provided in the model file, an empty label list is returned.
|
||||
*/
|
||||
List<String> getLabels() {
|
||||
return labels;
|
||||
}
|
||||
|
||||
/** Options for setting up an {@link ImageSegmenter}. */
|
||||
@AutoValue
|
||||
public abstract static class ImageSegmenterOptions extends TaskOptions {
|
||||
|
|
|
@ -7,6 +7,7 @@ VERS_1.0 {
|
|||
Java_com_google_mediapipe_framework_Graph_nativeAddPacketToInputStream;
|
||||
Java_com_google_mediapipe_framework_Graph_nativeCloseAllPacketSources;
|
||||
Java_com_google_mediapipe_framework_Graph_nativeCreateGraph;
|
||||
Java_com_google_mediapipe_framework_Graph_nativeGetCalculatorGraphConfig;
|
||||
Java_com_google_mediapipe_framework_Graph_nativeLoadBinaryGraph*;
|
||||
Java_com_google_mediapipe_framework_Graph_nativeMovePacketToInputStream;
|
||||
Java_com_google_mediapipe_framework_Graph_nativeReleaseGraph;
|
||||
|
|
|
@ -34,6 +34,7 @@ import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenter.ImageSegm
|
|||
import java.io.InputStream;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
|
@ -135,6 +136,45 @@ public class ImageSegmenterTest {
|
|||
// MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
||||
// verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
||||
// }
|
||||
|
||||
@Test
|
||||
public void getLabels_success() throws Exception {
|
||||
final List<String> expectedLabels =
|
||||
Arrays.asList(
|
||||
"background",
|
||||
"aeroplane",
|
||||
"bicycle",
|
||||
"bird",
|
||||
"boat",
|
||||
"bottle",
|
||||
"bus",
|
||||
"car",
|
||||
"cat",
|
||||
"chair",
|
||||
"cow",
|
||||
"dining table",
|
||||
"dog",
|
||||
"horse",
|
||||
"motorbike",
|
||||
"person",
|
||||
"potted plant",
|
||||
"sheep",
|
||||
"sofa",
|
||||
"train",
|
||||
"tv");
|
||||
ImageSegmenterOptions options =
|
||||
ImageSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
||||
.build();
|
||||
ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
List<String> actualLabels = imageSegmenter.getLabels();
|
||||
assertThat(actualLabels.size()).isEqualTo(expectedLabels.size());
|
||||
for (int i = 0; i < actualLabels.size(); i++) {
|
||||
assertThat(actualLabels.get(i)).isEqualTo(expectedLabels.get(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
|
|
|
@ -7,7 +7,9 @@ package(
|
|||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
exports_files(["metadata_schema.fbs"])
|
||||
exports_files(glob([
|
||||
"*.fbs",
|
||||
]))
|
||||
|
||||
# Generic schema for model metadata.
|
||||
flatbuffer_cc_library(
|
||||
|
@ -24,3 +26,13 @@ flatbuffer_py_library(
|
|||
name = "metadata_schema_py",
|
||||
srcs = ["metadata_schema.fbs"],
|
||||
)
|
||||
|
||||
flatbuffer_cc_library(
|
||||
name = "image_segmenter_metadata_schema_cc",
|
||||
srcs = ["image_segmenter_metadata_schema.fbs"],
|
||||
)
|
||||
|
||||
flatbuffer_py_library(
|
||||
name = "image_segmenter_metadata_schema_py",
|
||||
srcs = ["image_segmenter_metadata_schema.fbs"],
|
||||
)
|
||||
|
|
59
mediapipe/tasks/metadata/image_segmenter_metadata_schema.fbs
Normal file
|
@ -0,0 +1,59 @@
|
|||
// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
namespace mediapipe.tasks;
|
||||
|
||||
// Image segmenter metadata contains information specific for the image
|
||||
// segmentation task. The metadata can be added in
|
||||
// SubGraphMetadata.custom_metadata [1] in model metadata.
|
||||
// [1]: https://github.com/google/mediapipe/blob/46b5c4012d2ef76c9d92bb0d88a6b107aee83814/mediapipe/tasks/metadata/metadata_schema.fbs#L685
|
||||
|
||||
// ImageSegmenterOptions.min_parser_version indicates the minimum necessary
|
||||
// image segmenter metadata parser version to fully understand all fields in a
|
||||
// given metadata flatbuffer. This min_parser_version is specific for the
|
||||
// image segmenter metadata defined in this schema file.
|
||||
//
|
||||
// New fields and types will have associated comments with the schema version
|
||||
// for which they were added.
|
||||
//
|
||||
// Schema Semantic version: 1.0.0
|
||||
|
||||
// This indicates the flatbuffer compatibility. The number will be bumped when a
|
||||
// breaking change is applied to the schema, such as removing fields or adding new
|
||||
// fields to the middle of a table.
|
||||
file_identifier "V001";
|
||||
|
||||
// History:
|
||||
// 1.0.0 - Initial version.
|
||||
|
||||
// Supported activation functions.
|
||||
enum Activation: byte {
|
||||
NONE = 0,
|
||||
SIGMOID = 1,
|
||||
SOFTMAX = 2
|
||||
}
|
||||
|
||||
table ImageSegmenterOptions {
|
||||
// The activation function of the output layer in the image segmenter.
|
||||
activation: Activation;
|
||||
|
||||
// The minimum necessary image segmenter metadata parser version to fully
|
||||
// understand all fields in a given metadata flatbuffer. This field is
|
||||
// automatically populated by the MetadataPopulator when the metadata is
|
||||
// populated into a TFLite model. This min_parser_version is specific for the
|
||||
// image segmenter metadata defined in this schema file.
|
||||
min_parser_version:string;
|
||||
}
|
||||
|
||||
root_type ImageSegmenterOptions;
|
|
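For orientation, here is a minimal Python sketch of packing an `ImageSegmenterOptions` table from this schema with the generated object API; the generated module name matches the import used by the new image segmenter metadata writer further below, and the snippet is a sketch rather than part of the change.

```
import flatbuffers
from mediapipe.tasks.metadata import image_segmenter_metadata_schema_py_generated as segmenter_fb

# Build the options table and serialize it with the "V001" file identifier
# declared in this schema.
options = segmenter_fb.ImageSegmenterOptionsT()
options.activation = 1  # SIGMOID, per the Activation enum above

builder = flatbuffers.Builder(0)
builder.Finish(options.Pack(builder), b"V001")
segmenter_options_buffer = builder.Output()
```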
@ -233,7 +233,7 @@ table ImageProperties {
|
|||
//
|
||||
// <Codegen usage>:
|
||||
// Input image tensors: NA.
|
||||
// Output image tensors: parses the values into a data stucture that represents
|
||||
// Output image tensors: parses the values into a data structure that represents
|
||||
// bounding boxes. For example, in the generated wrapper for Android, it returns
|
||||
// the output as android.graphics.Rect objects.
|
||||
enum BoundingBoxType : byte {
|
||||
|
@ -389,7 +389,7 @@ table NormalizationOptions{
|
|||
// mean and std are normalization parameters. Tensor values are normalized
|
||||
// on a per-channel basis, by the formula
|
||||
// (x - mean) / std.
|
||||
// If there is only one value in mean or std, we'll propogate the value to
|
||||
// If there is only one value in mean or std, we'll propagate the value to
|
||||
// all channels.
|
||||
//
|
||||
// Quantized models share the same normalization parameters as their
|
||||
|
@ -526,7 +526,7 @@ table Stats {
|
|||
// Max and min are not currently used in tflite.support codegen. They mainly
|
||||
// serve as references for users to better understand the model. They can also
|
||||
// be used to validate model pre/post processing results.
|
||||
// If there is only one value in max or min, we'll propogate the value to
|
||||
// If there is only one value in max or min, we'll propagate the value to
|
||||
// all channels.
|
||||
|
||||
// Per-channel maximum value of the tensor.
|
||||
|
@ -542,7 +542,7 @@ table Stats {
|
|||
// has four outputs: classes, scores, bounding boxes, and number of detections.
|
||||
// If the four outputs are bundled together using TensorGroup (for example,
|
||||
// named as "detection result"), the codegen tool will generate the class,
|
||||
// `DetectionResult`, which contains the class, score, and bouding box. And the
|
||||
// `DetectionResult`, which contains the class, score, and bounding box. And the
|
||||
// outputs of the model will be converted to a list of `DetectionResults` and
|
||||
// the number of detection. Note that the number of detection is a single
|
||||
// number, therefore is inappropriate for the list of `DetectionResult`.
|
||||
|
@ -624,7 +624,7 @@ table SubGraphMetadata {
|
|||
// A description explains details about what the subgraph does.
|
||||
description:string;
|
||||
|
||||
// Metadata of all input tensors used in this subgraph. It matches extactly
|
||||
// Metadata of all input tensors used in this subgraph. It matches exactly
|
||||
// the input tensors specified by `SubGraph.inputs` in the TFLite
|
||||
// schema.fbs file[2]. The number of `TensorMetadata` in the array should
|
||||
// equal to the number of indices in `SubGraph.inputs`.
|
||||
|
@ -634,7 +634,7 @@ table SubGraphMetadata {
|
|||
// Determines how to process the inputs.
|
||||
input_tensor_metadata:[TensorMetadata];
|
||||
|
||||
// Metadata of all output tensors used in this subgraph. It matches extactly
|
||||
// Metadata of all output tensors used in this subgraph. It matches exactly
|
||||
// the output tensors specified by `SubGraph.outputs` in the TFLite
|
||||
// schema.fbs file[2]. The number of `TensorMetadata` in the array should
|
||||
// equal to the number of indices in `SubGraph.outputs`.
|
||||
|
@ -724,7 +724,7 @@ table ModelMetadata {
|
|||
// number among the versions of all the fields populated and the smallest
|
||||
// compatible version indicated by the file identifier.
|
||||
//
|
||||
// This field is automaticaly populated by the MetadataPopulator when
|
||||
// This field is automatically populated by the MetadataPopulator when
|
||||
// the metadata is populated into a TFLite model.
|
||||
min_parser_version:string;
|
||||
}
|
||||
|
|
|
@ -17,10 +17,13 @@
|
|||
import copy
|
||||
import inspect
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import Dict, Optional
|
||||
import warnings
|
||||
import zipfile
|
||||
|
||||
|
@ -789,13 +792,43 @@ class MetadataDisplayer(object):
|
|||
return []
|
||||
|
||||
|
||||
def _get_custom_metadata(metadata_buffer: bytes, name: str):
|
||||
"""Gets the custom metadata in metadata_buffer based on the name.
|
||||
|
||||
Args:
|
||||
metadata_buffer: valid metadata buffer in bytes.
|
||||
name: custom metadata name.
|
||||
|
||||
Returns:
|
||||
Index of custom metadata, custom metadata flatbuffer. Returns (None, None)
|
||||
if the custom metadata is not found.
|
||||
"""
|
||||
model_metadata = _metadata_fb.ModelMetadata.GetRootAs(metadata_buffer)
|
||||
subgraph = model_metadata.SubgraphMetadata(0)
|
||||
if subgraph is None or subgraph.CustomMetadataIsNone():
|
||||
return None, None
|
||||
|
||||
for i in range(subgraph.CustomMetadataLength()):
|
||||
custom_metadata = subgraph.CustomMetadata(i)
|
||||
if custom_metadata.Name().decode("utf-8") == name:
|
||||
return i, custom_metadata.DataAsNumpy().tobytes()
|
||||
return None, None
|
||||
|
||||
|
||||
# Create an individual method for getting the metadata json file, so that it can
|
||||
# be used as a standalone util.
|
||||
def convert_to_json(metadata_buffer):
|
||||
def convert_to_json(
|
||||
metadata_buffer, custom_metadata_schema: Optional[Dict[str, str]] = None
|
||||
) -> str:
|
||||
"""Converts the metadata into a json string.
|
||||
|
||||
Args:
|
||||
metadata_buffer: valid metadata buffer in bytes.
|
||||
custom_metadata_schema: A dict of custom metadata schema, in which key is
|
||||
custom metadata name [1], value is the filepath that defines custom
|
||||
metadata schema. For instance, custom_metadata_schema =
|
||||
{"SEGMENTER_METADATA": "metadata/vision_tasks_metadata_schema.fbs"}. [1]:
|
||||
https://github.com/google/mediapipe/blob/46b5c4012d2ef76c9d92bb0d88a6b107aee83814/mediapipe/tasks/metadata/metadata_schema.fbs#L612
|
||||
|
||||
Returns:
|
||||
Metadata in JSON format.
|
||||
|
@ -803,7 +836,6 @@ def convert_to_json(metadata_buffer):
|
|||
Raises:
|
||||
ValueError: error occurred when parsing the metadata schema file.
|
||||
"""
|
||||
|
||||
opt = _pywrap_flatbuffers.IDLOptions()
|
||||
opt.strict_json = True
|
||||
parser = _pywrap_flatbuffers.Parser(opt)
|
||||
|
@ -811,7 +843,35 @@ def convert_to_json(metadata_buffer):
|
|||
metadata_schema_content = f.read()
|
||||
if not parser.parse(metadata_schema_content):
|
||||
raise ValueError("Cannot parse metadata schema. Reason: " + parser.error)
|
||||
return _pywrap_flatbuffers.generate_text(parser, metadata_buffer)
|
||||
# Json content which may contain binary custom metadata.
|
||||
raw_json_content = _pywrap_flatbuffers.generate_text(parser, metadata_buffer)
|
||||
if not custom_metadata_schema:
|
||||
return raw_json_content
|
||||
|
||||
json_data = json.loads(raw_json_content)
|
||||
# Gets the custom metadata by name and parses the binary custom metadata into
|
||||
# human-readable JSON content.
|
||||
for name, schema_file in custom_metadata_schema.items():
|
||||
idx, custom_metadata = _get_custom_metadata(metadata_buffer, name)
|
||||
if not custom_metadata:
|
||||
logging.info(
|
||||
"No custom metadata with name %s in metadata flatbuffer.", name
|
||||
)
|
||||
continue
|
||||
_assert_file_exist(schema_file)
|
||||
with _open_file(schema_file, "rb") as f:
|
||||
custom_metadata_schema_content = f.read()
|
||||
if not parser.parse(custom_metadata_schema_content):
|
||||
raise ValueError(
|
||||
"Cannot parse custom metadata schema. Reason: " + parser.error
|
||||
)
|
||||
custom_metadata_json = _pywrap_flatbuffers.generate_text(
|
||||
parser, custom_metadata
|
||||
)
|
||||
json_meta = json_data["subgraph_metadata"][0]["custom_metadata"][idx]
|
||||
json_meta["name"] = name
|
||||
json_meta["data"] = json.loads(custom_metadata_json)
|
||||
return json.dumps(json_data, indent=2)
|
||||
|
||||
|
||||
def _assert_file_exist(filename):
|
||||
|
|
|
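A minimal usage sketch of the extended `convert_to_json` API: the model path below is a placeholder, while the `SEGMENTER_METADATA` name and the schema file path come from the files added in this change.

```
from mediapipe.tasks.python.metadata import metadata

# Placeholder path: any TFLite model whose metadata carries the
# SEGMENTER_METADATA custom metadata entry.
with open("segmenter_with_metadata.tflite", "rb") as f:
    model_buffer = f.read()

metadata_buffer = metadata.get_metadata_buffer(model_buffer)
json_content = metadata.convert_to_json(
    metadata_buffer,
    custom_metadata_schema={
        "SEGMENTER_METADATA": "mediapipe/tasks/metadata/image_segmenter_metadata_schema.fbs"
    },
)
```

Without the `custom_metadata_schema` argument the behavior is unchanged and any custom metadata stays as raw bytes in the returned JSON.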
@ -50,6 +50,20 @@ py_library(
|
|||
deps = [":metadata_writer"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "image_segmenter",
|
||||
srcs = ["image_segmenter.py"],
|
||||
data = ["//mediapipe/tasks/metadata:image_segmenter_metadata_schema.fbs"],
|
||||
deps = [
|
||||
":metadata_info",
|
||||
":metadata_writer",
|
||||
"//mediapipe/tasks/metadata:image_segmenter_metadata_schema_py",
|
||||
"//mediapipe/tasks/metadata:metadata_schema_py",
|
||||
"//mediapipe/tasks/python/metadata",
|
||||
"@flatbuffers//:runtime_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "object_detector",
|
||||
srcs = ["object_detector.py"],
|
||||
|
|
|
@ -0,0 +1,161 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Writes metadata and label file to the image segmenter models."""
|
||||
import enum
|
||||
from typing import List, Optional
|
||||
|
||||
import flatbuffers
|
||||
from mediapipe.tasks.metadata import image_segmenter_metadata_schema_py_generated as _segmenter_metadata_fb
|
||||
from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb
|
||||
from mediapipe.tasks.python.metadata import metadata
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_info
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||
|
||||
|
||||
_MODEL_NAME = "ImageSegmenter"
|
||||
_MODEL_DESCRIPTION = (
|
||||
"Semantic image segmentation predicts whether each pixel "
|
||||
"of an image is associated with a certain class."
|
||||
)
|
||||
|
||||
# Metadata Schema file for image segmenter.
|
||||
_FLATC_METADATA_SCHEMA_FILE = metadata.get_path_to_datafile(
|
||||
"../../../metadata/image_segmenter_metadata_schema.fbs",
|
||||
)
|
||||
|
||||
# Metadata name in custom metadata field. The metadata name is used to get
|
||||
# image segmenter metadata from SubGraphMetadata.custom_metadata and
|
||||
# shouldn't be changed.
|
||||
_METADATA_NAME = "SEGMENTER_METADATA"
|
||||
|
||||
|
||||
class Activation(enum.Enum):
|
||||
NONE = 0
|
||||
SIGMOID = 1
|
||||
SOFTMAX = 2
|
||||
|
||||
|
||||
# Create an individual method for getting the metadata json file, so that it can
|
||||
# be used as a standalone util.
|
||||
def convert_to_json(metadata_buffer: bytearray) -> str:
|
||||
"""Converts the metadata into a json string.
|
||||
|
||||
Args:
|
||||
metadata_buffer: valid metadata buffer in bytes.
|
||||
|
||||
Returns:
|
||||
Metadata in JSON format.
|
||||
|
||||
Raises:
|
||||
ValueError: error occurred when parsing the metadata schema file.
|
||||
"""
|
||||
return metadata.convert_to_json(
|
||||
metadata_buffer,
|
||||
custom_metadata_schema={_METADATA_NAME: _FLATC_METADATA_SCHEMA_FILE},
|
||||
)
|
||||
|
||||
|
||||
class ImageSegmenterOptionsMd(metadata_info.CustomMetadataMd):
|
||||
"""Image segmenter options metadata."""
|
||||
|
||||
_METADATA_FILE_IDENTIFIER = b"V001"
|
||||
|
||||
def __init__(self, activation: Activation) -> None:
|
||||
"""Creates an ImageSegmenterOptionsMd object.
|
||||
|
||||
Args:
|
||||
activation: activation function of the output layer in the image
|
||||
segmenter.
|
||||
"""
|
||||
self.activation = activation
|
||||
super().__init__(name=_METADATA_NAME)
|
||||
|
||||
def create_metadata(self) -> _metadata_fb.CustomMetadataT:
|
||||
"""Creates the image segmenter options metadata.
|
||||
|
||||
Returns:
|
||||
A Flatbuffers Python object of the custom metadata including image
|
||||
segmenter options metadata.
|
||||
"""
|
||||
segmenter_options = _segmenter_metadata_fb.ImageSegmenterOptionsT()
|
||||
segmenter_options.activation = self.activation.value
|
||||
|
||||
# Get the image segmenter options flatbuffer.
|
||||
b = flatbuffers.Builder(0)
|
||||
b.Finish(segmenter_options.Pack(b), self._METADATA_FILE_IDENTIFIER)
|
||||
segmenter_options_buf = b.Output()
|
||||
|
||||
# Add the image segmenter options flatbuffer in custom metadata.
|
||||
custom_metadata = _metadata_fb.CustomMetadataT()
|
||||
custom_metadata.name = self.name
|
||||
custom_metadata.data = segmenter_options_buf
|
||||
return custom_metadata
|
||||
|
||||
|
||||
class MetadataWriter(metadata_writer.MetadataWriterBase):
|
||||
"""MetadataWriter to write the metadata for image segmenter."""
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
model_buffer: bytearray,
|
||||
input_norm_mean: List[float],
|
||||
input_norm_std: List[float],
|
||||
labels: Optional[metadata_writer.Labels] = None,
|
||||
activation: Optional[Activation] = None,
|
||||
) -> "MetadataWriter":
|
||||
"""Creates MetadataWriter to write the metadata for image segmenter.
|
||||
|
||||
The parameters required in this method are mandatory when using MediaPipe
|
||||
Tasks.
|
||||
|
||||
Example usage:
|
||||
metadata_writer = image_segmenter.MetadataWriter.create(model_buffer, ...)
|
||||
tflite_content, json_content = metadata_writer.populate()
|
||||
|
||||
Calling the `populate` function of this class returns the TFLite content
|
||||
and the JSON content. Note that only the output TFLite model is used for
|
||||
deployment; the output JSON content is used to interpret the metadata content.
|
||||
|
||||
Args:
|
||||
model_buffer: A valid flatbuffer loaded from the TFLite model file.
|
||||
input_norm_mean: the mean value used in the input tensor normalization
|
||||
[1].
|
||||
input_norm_std: the std value used in the input tensor normalization [1].
|
||||
labels: an instance of Labels helper class used in the output category
|
||||
tensor [2].
|
||||
activation: activation function for the output layer.
|
||||
[1]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L389
|
||||
[2]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L116
|
||||
|
||||
Returns:
|
||||
A MetadataWriter object.
|
||||
"""
|
||||
writer = metadata_writer.MetadataWriter(model_buffer)
|
||||
writer.add_general_info(_MODEL_NAME, _MODEL_DESCRIPTION)
|
||||
writer.add_image_input(input_norm_mean, input_norm_std)
|
||||
writer.add_segmentation_output(labels=labels)
|
||||
if activation is not None:
|
||||
option_md = ImageSegmenterOptionsMd(activation)
|
||||
writer.add_custom_metadata(option_md)
|
||||
return cls(writer)
|
||||
|
||||
def populate(self) -> tuple[bytearray, str]:
|
||||
model_buf, _ = super().populate()
|
||||
metadata_buf = metadata.get_metadata_buffer(model_buf)
|
||||
json_content = convert_to_json(metadata_buf)
|
||||
return model_buf, json_content
|
|
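A usage sketch of the new writer that follows the docstring above; the model and label file paths are placeholders (the test added below uses deeplabv3_without_metadata.tflite and segmenter_labelmap.txt).

```
from mediapipe.tasks.python.metadata.metadata_writers import image_segmenter
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer

with open("deeplabv3_without_metadata.tflite", "rb") as f:  # placeholder path
    model_buffer = f.read()

writer = image_segmenter.MetadataWriter.create(
    bytearray(model_buffer),
    [127.5],  # input_norm_mean
    [127.5],  # input_norm_std
    labels=metadata_writer.Labels().add_from_file("segmenter_labelmap.txt"),
    activation=image_segmenter.Activation.SOFTMAX,
)
tflite_with_metadata, metadata_json = writer.populate()
```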
@ -1030,6 +1030,52 @@ class TensorGroupMd:
|
|||
return group
|
||||
|
||||
|
||||
class SegmentationMaskMd(TensorMd):
|
||||
"""A container for the segmentation mask metadata information."""
|
||||
|
||||
# The output tensor is in the shape of [1, ImageHeight, ImageWidth, N], where
|
||||
# N is the number of objects that the segmentation model can recognize. The
|
||||
# output tensor is essentially a list of grayscale bitmaps, where each value
|
||||
# is the probability of the corresponding pixel belonging to a certain object
|
||||
# type. Therefore, the content dimension range of the output tensor is [1, 2].
|
||||
_CONTENT_DIM_MIN = 1
|
||||
_CONTENT_DIM_MAX = 2
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
label_files: Optional[List[LabelFileMd]] = None,
|
||||
):
|
||||
self.name = name
|
||||
self.description = description
|
||||
associated_files = label_files or []
|
||||
super().__init__(
|
||||
name=name, description=description, associated_files=associated_files
|
||||
)
|
||||
|
||||
def create_metadata(self) -> _metadata_fb.TensorMetadataT:
|
||||
"""Creates the metadata for the segmentation masks tensor."""
|
||||
masks_metadata = super().create_metadata()
|
||||
|
||||
# Create tensor content information.
|
||||
content = _metadata_fb.ContentT()
|
||||
content.contentProperties = _metadata_fb.ImagePropertiesT()
|
||||
content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.GRAYSCALE
|
||||
content.contentPropertiesType = (
|
||||
_metadata_fb.ContentProperties.ImageProperties
|
||||
)
|
||||
# Add the content range. See
|
||||
# https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L323-L385
|
||||
dim_range = _metadata_fb.ValueRangeT()
|
||||
dim_range.min = self._CONTENT_DIM_MIN
|
||||
dim_range.max = self._CONTENT_DIM_MAX
|
||||
content.range = dim_range
|
||||
masks_metadata.content = content
|
||||
|
||||
return masks_metadata
|
||||
|
||||
|
||||
class CustomMetadataMd(abc.ABC):
|
||||
"""An abstract class of a container for the custom metadata information."""
|
||||
|
||||
|
|
|
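A small sketch of building the new tensor metadata directly, mirroring what the image segmenter writer does; the name and description strings match the segmentation_mask_meta.json test data added below.

```
from mediapipe.tasks.python.metadata.metadata_writers import metadata_info

mask_md = metadata_info.SegmentationMaskMd(
    name="segmentation_masks",
    description="Masks over the target objects.",
)
# Returns a TensorMetadataT whose content is GRAYSCALE ImageProperties
# with a [1, 2] content dimension range.
tensor_metadata = mask_md.create_metadata()
```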
@ -34,6 +34,10 @@ _INPUT_REGEX_TEXT_DESCRIPTION = ('Embedding vectors representing the input '
|
|||
'text to be processed.')
|
||||
_OUTPUT_CLASSIFICATION_NAME = 'score'
|
||||
_OUTPUT_CLASSIFICATION_DESCRIPTION = 'Score of the labels respectively.'
|
||||
_OUTPUT_SEGMENTATION_MASKS_NAME = 'segmentation_masks'
|
||||
_OUTPUT_SEGMENTATION_MASKS_DESCRIPTION = (
|
||||
'Masks over the target objects with high accuracy.'
|
||||
)
|
||||
# Detection tensor result to be grouped together.
|
||||
_DETECTION_GROUP_NAME = 'detection_result'
|
||||
# File name to export score calibration parameters.
|
||||
|
@ -657,6 +661,32 @@ class MetadataWriter(object):
|
|||
self._output_group_mds.append(group_md)
|
||||
return self
|
||||
|
||||
def add_segmentation_output(
|
||||
self,
|
||||
labels: Optional[Labels] = None,
|
||||
name: str = _OUTPUT_SEGMENTATION_MASKS_NAME,
|
||||
description: str = _OUTPUT_SEGMENTATION_MASKS_DESCRIPTION,
|
||||
) -> 'MetadataWriter':
|
||||
"""Adds segmentation head metadata for the segmentation output tensor.
|
||||
|
||||
Args:
|
||||
labels: an instance of Labels helper class.
|
||||
name: Metadata name of the tensor. Note that this is different from tensor
|
||||
name in the flatbuffer.
|
||||
description: human readable description of what the output is.
|
||||
|
||||
Returns:
|
||||
The current Writer instance to allow chained operation.
|
||||
"""
|
||||
label_files = self._create_label_file_md(labels)
|
||||
output_md = metadata_info.SegmentationMaskMd(
|
||||
name=name,
|
||||
description=description,
|
||||
label_files=label_files,
|
||||
)
|
||||
self._output_mds.append(output_md)
|
||||
return self
|
||||
|
||||
def add_feature_output(self,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None) -> 'MetadataWriter':
|
||||
|
|
|
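A sketch of calling `add_segmentation_output` on the generic writer, following the same sequence the image segmenter writer's `create` method uses; `model_buffer` and the label file path are assumed placeholders, and `populate()` is assumed to return the populated model plus its JSON description, as the image segmenter writer above relies on.

```
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer

writer = metadata_writer.MetadataWriter(bytearray(model_buffer))
writer.add_general_info("ImageSegmenter", "Semantic image segmentation model.")
writer.add_image_input([127.5], [127.5])
writer.add_segmentation_output(
    labels=metadata_writer.Labels().add_from_file("labels.txt")  # placeholder label file
)
populated_model, metadata_json = writer.populate()
```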
@ -91,3 +91,18 @@ py_test(
|
|||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "image_segmenter_test",
|
||||
srcs = ["image_segmenter_test.py"],
|
||||
data = [
|
||||
"//mediapipe/tasks/testdata/metadata:data_files",
|
||||
"//mediapipe/tasks/testdata/metadata:model_files",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/tasks/python/metadata",
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:image_segmenter",
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -0,0 +1,98 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for metadata_writer.image_segmenter."""
|
||||
|
||||
import os
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
from mediapipe.tasks.python.metadata import metadata
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import image_segmenter
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
|
||||
_TEST_DATA_DIR = "mediapipe/tasks/testdata/metadata"
|
||||
_MODEL_FILE = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, "deeplabv3_without_metadata.tflite")
|
||||
)
|
||||
_LABEL_FILE_NAME = "labels.txt"
|
||||
_LABEL_FILE = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, "segmenter_labelmap.txt")
|
||||
)
|
||||
_NORM_MEAN = 127.5
|
||||
_NORM_STD = 127.5
|
||||
_JSON_FILE = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, "deeplabv3.json")
|
||||
)
|
||||
_JSON_FILE_WITHOUT_LABELS = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, "deeplabv3_without_labels.json")
|
||||
)
|
||||
_JSON_FILE_WITH_ACTIVATION = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, "deeplabv3_with_activation.json")
|
||||
)
|
||||
|
||||
|
||||
class ImageSegmenterTest(absltest.TestCase):
|
||||
|
||||
def test_write_metadata(self):
|
||||
with open(_MODEL_FILE, "rb") as f:
|
||||
model_buffer = f.read()
|
||||
writer = image_segmenter.MetadataWriter.create(
|
||||
bytearray(model_buffer),
|
||||
[_NORM_MEAN],
|
||||
[_NORM_STD],
|
||||
labels=metadata_writer.Labels().add_from_file(_LABEL_FILE),
|
||||
)
|
||||
tflite_content, metadata_json = writer.populate()
|
||||
with open(_JSON_FILE, "r") as f:
|
||||
expected_json = f.read().strip()
|
||||
self.assertEqual(metadata_json, expected_json)
|
||||
|
||||
displayer = metadata.MetadataDisplayer.with_model_buffer(tflite_content)
|
||||
label_file_buffer = displayer.get_associated_file_buffer(_LABEL_FILE_NAME)
|
||||
with open(_LABEL_FILE, "rb") as f:
|
||||
expected_labelfile_buffer = f.read()
|
||||
self.assertEqual(label_file_buffer, expected_labelfile_buffer)
|
||||
|
||||
def test_write_metadata_without_labels(self):
|
||||
with open(_MODEL_FILE, "rb") as f:
|
||||
model_buffer = f.read()
|
||||
writer = image_segmenter.MetadataWriter.create(
|
||||
bytearray(model_buffer),
|
||||
[_NORM_MEAN],
|
||||
[_NORM_STD],
|
||||
)
|
||||
_, metadata_json = writer.populate()
|
||||
with open(_JSON_FILE_WITHOUT_LABELS, "r") as f:
|
||||
expected_json = f.read().strip()
|
||||
self.assertEqual(metadata_json, expected_json)
|
||||
|
||||
def test_write_metadata_with_activation(self):
|
||||
with open(_MODEL_FILE, "rb") as f:
|
||||
model_buffer = f.read()
|
||||
writer = image_segmenter.MetadataWriter.create(
|
||||
bytearray(model_buffer),
|
||||
[_NORM_MEAN],
|
||||
[_NORM_STD],
|
||||
activation=image_segmenter.Activation.SIGMOID,
|
||||
)
|
||||
_, metadata_json = writer.populate()
|
||||
with open(_JSON_FILE_WITH_ACTIVATION, "r") as f:
|
||||
expected_json = f.read().strip()
|
||||
self.assertEqual(metadata_json, expected_json)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
|
@ -455,6 +455,27 @@ class TensorGroupMdMdTest(absltest.TestCase):
|
|||
self.assertEqual(metadata_json, expected_json)
|
||||
|
||||
|
||||
class SegmentationMaskMdTest(absltest.TestCase):
|
||||
_NAME = "segmentation_masks"
|
||||
_DESCRIPTION = "Masks over the target objects."
|
||||
_EXPECTED_JSON = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, "segmentation_mask_meta.json")
|
||||
)
|
||||
|
||||
def test_create_metadata_should_succeed(self):
|
||||
segmentation_mask_md = metadata_info.SegmentationMaskMd(
|
||||
name=self._NAME, description=self._DESCRIPTION
|
||||
)
|
||||
metadata = segmentation_mask_md.create_metadata()
|
||||
|
||||
metadata_json = _metadata.convert_to_json(
|
||||
_create_dummy_model_metadata_with_tensor(metadata)
|
||||
)
|
||||
with open(self._EXPECTED_JSON, "r") as f:
|
||||
expected_json = f.read()
|
||||
self.assertEqual(metadata_json, expected_json)
|
||||
|
||||
|
||||
def _create_dummy_model_metadata_with_tensor(
|
||||
tensor_metadata: _metadata_fb.TensorMetadataT) -> bytes:
|
||||
# Create a dummy model using the tensor metadata.
|
||||
|
|
|
@ -42,48 +42,62 @@ _RUNNING_MODE = running_mode_module.VisionTaskRunningMode
|
|||
|
||||
_MODEL_FILE = 'coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite'
|
||||
_IMAGE_FILE = 'cats_and_dogs.jpg'
|
||||
_EXPECTED_DETECTION_RESULT = _DetectionResult(detections=[
|
||||
_Detection(
|
||||
bounding_box=_BoundingBox(
|
||||
origin_x=608, origin_y=161, width=381, height=439),
|
||||
categories=[
|
||||
_Category(
|
||||
index=None,
|
||||
score=0.69921875,
|
||||
display_name=None,
|
||||
category_name='cat')
|
||||
]),
|
||||
_Detection(
|
||||
bounding_box=_BoundingBox(
|
||||
origin_x=60, origin_y=398, width=386, height=196),
|
||||
categories=[
|
||||
_Category(
|
||||
index=None,
|
||||
score=0.64453125,
|
||||
display_name=None,
|
||||
category_name='cat')
|
||||
]),
|
||||
_Detection(
|
||||
bounding_box=_BoundingBox(
|
||||
origin_x=256, origin_y=395, width=173, height=202),
|
||||
categories=[
|
||||
_Category(
|
||||
index=None,
|
||||
score=0.51171875,
|
||||
display_name=None,
|
||||
category_name='cat')
|
||||
]),
|
||||
_Detection(
|
||||
bounding_box=_BoundingBox(
|
||||
origin_x=362, origin_y=191, width=325, height=419),
|
||||
categories=[
|
||||
_Category(
|
||||
index=None,
|
||||
score=0.48828125,
|
||||
display_name=None,
|
||||
category_name='cat')
|
||||
])
|
||||
])
|
||||
_EXPECTED_DETECTION_RESULT = _DetectionResult(
|
||||
detections=[
|
||||
_Detection(
|
||||
bounding_box=_BoundingBox(
|
||||
origin_x=608, origin_y=161, width=381, height=439
|
||||
),
|
||||
categories=[
|
||||
_Category(
|
||||
index=None,
|
||||
score=0.69921875,
|
||||
display_name=None,
|
||||
category_name='cat',
|
||||
)
|
||||
],
|
||||
),
|
||||
_Detection(
|
||||
bounding_box=_BoundingBox(
|
||||
origin_x=60, origin_y=398, width=386, height=196
|
||||
),
|
||||
categories=[
|
||||
_Category(
|
||||
index=None,
|
||||
score=0.64453125,
|
||||
display_name=None,
|
||||
category_name='cat',
|
||||
)
|
||||
],
|
||||
),
|
||||
_Detection(
|
||||
bounding_box=_BoundingBox(
|
||||
origin_x=256, origin_y=395, width=173, height=202
|
||||
),
|
||||
categories=[
|
||||
_Category(
|
||||
index=None,
|
||||
score=0.51171875,
|
||||
display_name=None,
|
||||
category_name='cat',
|
||||
)
|
||||
],
|
||||
),
|
||||
_Detection(
|
||||
bounding_box=_BoundingBox(
|
||||
origin_x=362, origin_y=191, width=325, height=419
|
||||
),
|
||||
categories=[
|
||||
_Category(
|
||||
index=None,
|
||||
score=0.48828125,
|
||||
display_name=None,
|
||||
category_name='cat',
|
||||
)
|
||||
],
|
||||
),
|
||||
]
|
||||
)
|
||||
_ALLOW_LIST = ['cat', 'dog']
|
||||
_DENY_LIST = ['cat']
|
||||
_SCORE_THRESHOLD = 0.3
|
||||
|
|
12
mediapipe/tasks/testdata/metadata/BUILD
vendored
|
@ -28,6 +28,10 @@ mediapipe_files(srcs = [
|
|||
"category_tensor_float_meta.json",
|
||||
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_no_metadata.tflite",
|
||||
"coco_ssd_mobilenet_v1_score_calibration.json",
|
||||
"deeplabv3.json",
|
||||
"deeplabv3_with_activation.json",
|
||||
"deeplabv3_without_labels.json",
|
||||
"deeplabv3_without_metadata.tflite",
|
||||
"efficientdet_lite0_v1.json",
|
||||
"efficientdet_lite0_v1.tflite",
|
||||
"labelmap.txt",
|
||||
|
@ -44,6 +48,8 @@ mediapipe_files(srcs = [
|
|||
"mobilenet_v2_1.0_224_without_metadata.tflite",
|
||||
"movie_review.tflite",
|
||||
"score_calibration.csv",
|
||||
"segmentation_mask_meta.json",
|
||||
"segmenter_labelmap.txt",
|
||||
"ssd_mobilenet_v1_no_metadata.json",
|
||||
"ssd_mobilenet_v1_no_metadata.tflite",
|
||||
"tensor_group_meta.json",
|
||||
|
@ -87,6 +93,7 @@ filegroup(
|
|||
"30k-clean.model",
|
||||
"bert_text_classifier_no_metadata.tflite",
|
||||
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_no_metadata.tflite",
|
||||
"deeplabv3_without_metadata.tflite",
|
||||
"efficientdet_lite0_v1.tflite",
|
||||
"mobile_ica_8bit-with-custom-metadata.tflite",
|
||||
"mobile_ica_8bit-with-large-min-parser-version.tflite",
|
||||
|
@ -116,6 +123,9 @@ filegroup(
|
|||
"classification_tensor_uint8_meta.json",
|
||||
"classification_tensor_unsupported_meta.json",
|
||||
"coco_ssd_mobilenet_v1_score_calibration.json",
|
||||
"deeplabv3.json",
|
||||
"deeplabv3_with_activation.json",
|
||||
"deeplabv3_without_labels.json",
|
||||
"efficientdet_lite0_v1.json",
|
||||
"external_file",
|
||||
"feature_tensor_meta.json",
|
||||
|
@ -140,6 +150,8 @@ filegroup(
|
|||
"score_calibration_file_meta.json",
|
||||
"score_calibration_tensor_meta.json",
|
||||
"score_thresholding_meta.json",
|
||||
"segmentation_mask_meta.json",
|
||||
"segmenter_labelmap.txt",
|
||||
"sentence_piece_tokenizer_meta.json",
|
||||
"ssd_mobilenet_v1_no_metadata.json",
|
||||
"tensor_group_meta.json",
|
||||
|
|
66
mediapipe/tasks/testdata/metadata/deeplabv3.json
vendored
Normal file
|
@ -0,0 +1,66 @@
|
|||
{
|
||||
"name": "ImageSegmenter",
|
||||
"description": "Semantic image segmentation predicts whether each pixel of an image is associated with a certain class.",
|
||||
"subgraph_metadata": [
|
||||
{
|
||||
"input_tensor_metadata": [
|
||||
{
|
||||
"name": "image",
|
||||
"description": "Input image to be processed.",
|
||||
"content": {
|
||||
"content_properties_type": "ImageProperties",
|
||||
"content_properties": {
|
||||
"color_space": "RGB"
|
||||
}
|
||||
},
|
||||
"process_units": [
|
||||
{
|
||||
"options_type": "NormalizationOptions",
|
||||
"options": {
|
||||
"mean": [
|
||||
127.5
|
||||
],
|
||||
"std": [
|
||||
127.5
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"stats": {
|
||||
"max": [
|
||||
1.0
|
||||
],
|
||||
"min": [
|
||||
-1.0
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"output_tensor_metadata": [
|
||||
{
|
||||
"name": "segmentation_masks",
|
||||
"description": "Masks over the target objects with high accuracy.",
|
||||
"content": {
|
||||
"content_properties_type": "ImageProperties",
|
||||
"content_properties": {
|
||||
"color_space": "GRAYSCALE"
|
||||
},
|
||||
"range": {
|
||||
"min": 1,
|
||||
"max": 2
|
||||
}
|
||||
},
|
||||
"stats": {},
|
||||
"associated_files": [
|
||||
{
|
||||
"name": "labels.txt",
|
||||
"description": "Labels for categories that the model can recognize.",
|
||||
"type": "TENSOR_AXIS_LABELS"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"min_parser_version": "1.0.0"
|
||||
}
|
67
mediapipe/tasks/testdata/metadata/deeplabv3_with_activation.json
vendored
Normal file
|
@ -0,0 +1,67 @@
|
|||
{
|
||||
"name": "ImageSegmenter",
|
||||
"description": "Semantic image segmentation predicts whether each pixel of an image is associated with a certain class.",
|
||||
"subgraph_metadata": [
|
||||
{
|
||||
"input_tensor_metadata": [
|
||||
{
|
||||
"name": "image",
|
||||
"description": "Input image to be processed.",
|
||||
"content": {
|
||||
"content_properties_type": "ImageProperties",
|
||||
"content_properties": {
|
||||
"color_space": "RGB"
|
||||
}
|
||||
},
|
||||
"process_units": [
|
||||
{
|
||||
"options_type": "NormalizationOptions",
|
||||
"options": {
|
||||
"mean": [
|
||||
127.5
|
||||
],
|
||||
"std": [
|
||||
127.5
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"stats": {
|
||||
"max": [
|
||||
1.0
|
||||
],
|
||||
"min": [
|
||||
-1.0
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"output_tensor_metadata": [
|
||||
{
|
||||
"name": "segmentation_masks",
|
||||
"description": "Masks over the target objects with high accuracy.",
|
||||
"content": {
|
||||
"content_properties_type": "ImageProperties",
|
||||
"content_properties": {
|
||||
"color_space": "GRAYSCALE"
|
||||
},
|
||||
"range": {
|
||||
"min": 1,
|
||||
"max": 2
|
||||
}
|
||||
},
|
||||
"stats": {}
|
||||
}
|
||||
],
|
||||
"custom_metadata": [
|
||||
{
|
||||
"name": "SEGMENTER_METADATA",
|
||||
"data": {
|
||||
"activation": "SIGMOID"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"min_parser_version": "1.5.0"
|
||||
}
|
59
mediapipe/tasks/testdata/metadata/deeplabv3_without_labels.json
vendored
Normal file
|
@ -0,0 +1,59 @@
|
|||
{
|
||||
"name": "ImageSegmenter",
|
||||
"description": "Semantic image segmentation predicts whether each pixel of an image is associated with a certain class.",
|
||||
"subgraph_metadata": [
|
||||
{
|
||||
"input_tensor_metadata": [
|
||||
{
|
||||
"name": "image",
|
||||
"description": "Input image to be processed.",
|
||||
"content": {
|
||||
"content_properties_type": "ImageProperties",
|
||||
"content_properties": {
|
||||
"color_space": "RGB"
|
||||
}
|
||||
},
|
||||
"process_units": [
|
||||
{
|
||||
"options_type": "NormalizationOptions",
|
||||
"options": {
|
||||
"mean": [
|
||||
127.5
|
||||
],
|
||||
"std": [
|
||||
127.5
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"stats": {
|
||||
"max": [
|
||||
1.0
|
||||
],
|
||||
"min": [
|
||||
-1.0
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"output_tensor_metadata": [
|
||||
{
|
||||
"name": "segmentation_masks",
|
||||
"description": "Masks over the target objects with high accuracy.",
|
||||
"content": {
|
||||
"content_properties_type": "ImageProperties",
|
||||
"content_properties": {
|
||||
"color_space": "GRAYSCALE"
|
||||
},
|
||||
"range": {
|
||||
"min": 1,
|
||||
"max": 2
|
||||
}
|
||||
},
|
||||
"stats": {}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"min_parser_version": "1.0.0"
|
||||
}
|
24
mediapipe/tasks/testdata/metadata/segmentation_mask_meta.json
vendored
Normal file
|
@ -0,0 +1,24 @@
|
|||
{
|
||||
"subgraph_metadata": [
|
||||
{
|
||||
"input_tensor_metadata": [
|
||||
{
|
||||
"name": "segmentation_masks",
|
||||
"description": "Masks over the target objects.",
|
||||
"content": {
|
||||
"content_properties_type": "ImageProperties",
|
||||
"content_properties": {
|
||||
"color_space": "GRAYSCALE"
|
||||
},
|
||||
"range": {
|
||||
"min": 1,
|
||||
"max": 2
|
||||
}
|
||||
},
|
||||
"stats": {
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
21
mediapipe/tasks/testdata/metadata/segmenter_labelmap.txt
vendored
Normal file
|
@ -0,0 +1,21 @@
|
|||
background
|
||||
aeroplane
|
||||
bicycle
|
||||
bird
|
||||
boat
|
||||
bottle
|
||||
bus
|
||||
car
|
||||
cat
|
||||
chair
|
||||
cow
|
||||
dining table
|
||||
dog
|
||||
horse
|
||||
motorbike
|
||||
person
|
||||
potted plant
|
||||
sheep
|
||||
sofa
|
||||
train
|
||||
tv
|
9
mediapipe/tasks/testdata/vision/BUILD
vendored
|
@ -70,6 +70,8 @@ mediapipe_files(srcs = [
|
|||
"portrait.jpg",
|
||||
"portrait_hair_expected_mask.jpg",
|
||||
"portrait_rotated.jpg",
|
||||
"pose.jpg",
|
||||
"pose_detection.tflite",
|
||||
"right_hands.jpg",
|
||||
"right_hands_rotated.jpg",
|
||||
"segmentation_golden_rotation0.png",
|
||||
|
@ -78,6 +80,8 @@ mediapipe_files(srcs = [
|
|||
"selfie_segm_128_128_3_expected_mask.jpg",
|
||||
"selfie_segm_144_256_3.tflite",
|
||||
"selfie_segm_144_256_3_expected_mask.jpg",
|
||||
"selfie_segmentation.tflite",
|
||||
"selfie_segmentation_landscape.tflite",
|
||||
"thumb_up.jpg",
|
||||
"victory.jpg",
|
||||
])
|
||||
|
@ -125,6 +129,7 @@ filegroup(
|
|||
"portrait.jpg",
|
||||
"portrait_hair_expected_mask.jpg",
|
||||
"portrait_rotated.jpg",
|
||||
"pose.jpg",
|
||||
"right_hands.jpg",
|
||||
"right_hands_rotated.jpg",
|
||||
"segmentation_golden_rotation0.png",
|
||||
|
@ -170,8 +175,11 @@ filegroup(
|
|||
"mobilenet_v2_1.0_224.tflite",
|
||||
"mobilenet_v3_small_100_224_embedder.tflite",
|
||||
"palm_detection_full.tflite",
|
||||
"pose_detection.tflite",
|
||||
"selfie_segm_128_128_3.tflite",
|
||||
"selfie_segm_144_256_3.tflite",
|
||||
"selfie_segmentation.tflite",
|
||||
"selfie_segmentation_landscape.tflite",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -197,6 +205,7 @@ filegroup(
|
|||
"portrait_expected_face_landmarks.pbtxt",
|
||||
"portrait_expected_face_landmarks_with_attention.pbtxt",
|
||||
"portrait_rotated_expected_detection.pbtxt",
|
||||
"pose_expected_detection.pbtxt",
|
||||
"thumb_up_landmarks.pbtxt",
|
||||
"thumb_up_rotated_landmarks.pbtxt",
|
||||
"victory_landmarks.pbtxt",
|
||||
|
|
27
mediapipe/tasks/testdata/vision/pose_expected_detection.pbtxt
vendored
Normal file
|
@ -0,0 +1,27 @@
|
|||
# proto-file: mediapipe/framework/formats/detection.proto
|
||||
# proto-message: Detection
|
||||
location_data {
|
||||
format: BOUNDING_BOX
|
||||
bounding_box {
|
||||
xmin: 397
|
||||
ymin: 198
|
||||
width: 199
|
||||
height: 199
|
||||
}
|
||||
relative_keypoints {
|
||||
x: 0.4879558
|
||||
y: 0.7013345
|
||||
}
|
||||
relative_keypoints {
|
||||
x: 0.48453212
|
||||
y: 0.32265592
|
||||
}
|
||||
relative_keypoints {
|
||||
x: 0.4992165
|
||||
y: 0.4854874
|
||||
}
|
||||
relative_keypoints {
|
||||
x: 0.50227845
|
||||
y: 0.159788
|
||||
}
|
||||
}
|
|
@ -24,6 +24,7 @@ VISION_LIBS = [
|
|||
"//mediapipe/tasks/web/vision/image_classifier",
|
||||
"//mediapipe/tasks/web/vision/image_embedder",
|
||||
"//mediapipe/tasks/web/vision/image_segmenter",
|
||||
"//mediapipe/tasks/web/vision/interactive_segmenter",
|
||||
"//mediapipe/tasks/web/vision/object_detector",
|
||||
]
|
||||
|
||||
|
|
|
@ -75,6 +75,24 @@ imageSegmenter.segment(image, (masks, width, height) => {
|
|||
});
|
||||
```
|
||||
|
||||
## Interactive Segmentation
|
||||
|
||||
The MediaPipe Interactive Segmenter lets you select a region of interest and
|
||||
segments the image based on it.
|
||||
|
||||
```
|
||||
const vision = await FilesetResolver.forVisionTasks(
|
||||
"https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm"
|
||||
);
|
||||
const interactiveSegmenter = await InteractiveSegmenter.createFromModelPath(
|
||||
vision, "model.tflite"
|
||||
);
|
||||
const image = document.getElementById("image") as HTMLImageElement;
|
||||
interactiveSegmenter.segment(image, { keypoint: { x: 0.1, y: 0.2 } },
|
||||
(masks, width, height) => { ... }
|
||||
);
|
||||
```
|
||||
|
||||
## Object Detection
|
||||
|
||||
The MediaPipe Object Detector task lets you detect the presence and location of
|
||||
|
|
|
@ -20,6 +20,7 @@ import {HandLandmarker as HandLandmarkerImpl} from '../../../tasks/web/vision/ha
|
|||
import {ImageClassifier as ImageClassifierImpl} from '../../../tasks/web/vision/image_classifier/image_classifier';
|
||||
import {ImageEmbedder as ImageEmbedderImpl} from '../../../tasks/web/vision/image_embedder/image_embedder';
|
||||
import {ImageSegmenter as ImageSegementerImpl} from '../../../tasks/web/vision/image_segmenter/image_segmenter';
|
||||
import {InteractiveSegmenter as InteractiveSegmenterImpl} from '../../../tasks/web/vision/interactive_segmenter/interactive_segmenter';
|
||||
import {ObjectDetector as ObjectDetectorImpl} from '../../../tasks/web/vision/object_detector/object_detector';
|
||||
|
||||
// Declare the variables locally so that Rollup in OSS includes them explicitly
|
||||
|
@ -30,6 +31,7 @@ const HandLandmarker = HandLandmarkerImpl;
|
|||
const ImageClassifier = ImageClassifierImpl;
|
||||
const ImageEmbedder = ImageEmbedderImpl;
|
||||
const ImageSegmenter = ImageSegementerImpl;
|
||||
const InteractiveSegmenter = InteractiveSegmenterImpl;
|
||||
const ObjectDetector = ObjectDetectorImpl;
|
||||
|
||||
export {
|
||||
|
@ -39,5 +41,6 @@ export {
|
|||
ImageClassifier,
|
||||
ImageEmbedder,
|
||||
ImageSegmenter,
|
||||
InteractiveSegmenter,
|
||||
ObjectDetector
|
||||
};
|
||||
|
|
62
mediapipe/tasks/web/vision/interactive_segmenter/BUILD
Normal file
|
@ -0,0 +1,62 @@
|
|||
# This contains the MediaPipe Interactive Segmenter Task.
|
||||
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
|
||||
load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
mediapipe_ts_library(
|
||||
name = "interactive_segmenter",
|
||||
srcs = ["interactive_segmenter.ts"],
|
||||
deps = [
|
||||
":interactive_segmenter_types",
|
||||
"//mediapipe/framework:calculator_jspb_proto",
|
||||
"//mediapipe/framework:calculator_options_jspb_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_jspb_proto",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_jspb_proto",
|
||||
"//mediapipe/tasks/web/components/containers:keypoint",
|
||||
"//mediapipe/tasks/web/core",
|
||||
"//mediapipe/tasks/web/vision/core:image_processing_options",
|
||||
"//mediapipe/tasks/web/vision/core:types",
|
||||
"//mediapipe/tasks/web/vision/core:vision_task_runner",
|
||||
"//mediapipe/util:color_jspb_proto",
|
||||
"//mediapipe/util:render_data_jspb_proto",
|
||||
"//mediapipe/web/graph_runner:graph_runner_ts",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_ts_declaration(
|
||||
name = "interactive_segmenter_types",
|
||||
srcs = ["interactive_segmenter_options.d.ts"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/web/core",
|
||||
"//mediapipe/tasks/web/core:classifier_options",
|
||||
"//mediapipe/tasks/web/vision/core:vision_task_options",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_ts_library(
|
||||
name = "interactive_segmenter_test_lib",
|
||||
testonly = True,
|
||||
srcs = [
|
||||
"interactive_segmenter_test.ts",
|
||||
],
|
||||
deps = [
|
||||
":interactive_segmenter",
|
||||
":interactive_segmenter_types",
|
||||
"//mediapipe/framework:calculator_jspb_proto",
|
||||
"//mediapipe/tasks/web/core",
|
||||
"//mediapipe/tasks/web/core:task_runner_test_utils",
|
||||
"//mediapipe/util:render_data_jspb_proto",
|
||||
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
|
||||
],
|
||||
)
|
||||
|
||||
jasmine_node_test(
|
||||
name = "interactive_segmenter_test",
|
||||
tags = ["nomsan"],
|
||||
deps = [":interactive_segmenter_test_lib"],
|
||||
)
|
|
@ -0,0 +1,306 @@
|
|||
/**
|
||||
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
||||
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
|
||||
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
|
||||
import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options_pb';
|
||||
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
|
||||
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
||||
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
|
||||
import {RegionOfInterest, SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types';
|
||||
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
|
||||
import {Color as ColorProto} from '../../../../util/color_pb';
|
||||
import {RenderAnnotation as RenderAnnotationProto, RenderData as RenderDataProto} from '../../../../util/render_data_pb';
|
||||
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
|
||||
// Placeholder for internal dependency on trusted resource url
|
||||
|
||||
import {InteractiveSegmenterOptions} from './interactive_segmenter_options';
|
||||
|
||||
export * from './interactive_segmenter_options';
|
||||
export {SegmentationMask, SegmentationMaskCallback, RegionOfInterest};
|
||||
export {ImageSource};
|
||||
|
||||
const IMAGE_IN_STREAM = 'image_in';
|
||||
const NORM_RECT_IN_STREAM = 'norm_rect_in';
|
||||
const ROI_IN_STREAM = 'roi_in';
|
||||
const IMAGE_OUT_STREAM = 'image_out';
|
||||
const IMAGEA_SEGMENTER_GRAPH =
|
||||
'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph';
|
||||
|
||||
// The OSS JS API does not support the builder pattern.
|
||||
// tslint:disable:jspb-use-builder-pattern
|
||||
|
||||
/**
|
||||
* Performs interactive segmentation on images.
|
||||
*
|
||||
* Users can represent user interaction through `RegionOfInterest`, which gives
|
||||
* a hint to InteractiveSegmenter to perform segmentation focusing on the given
|
||||
* region of interest.
|
||||
*
|
||||
* The API expects a TFLite model with mandatory TFLite Model Metadata.
|
||||
*
|
||||
* Input tensor:
|
||||
* (kTfLiteUInt8/kTfLiteFloat32)
|
||||
* - image input of size `[batch x height x width x channels]`.
|
||||
* - batch inference is not supported (`batch` is required to be 1).
|
||||
* - RGB inputs are supported (`channels` is required to be 3).
|
||||
* - if type is kTfLiteFloat32, NormalizationOptions are required to be
|
||||
* attached to the metadata for input normalization.
|
||||
* Output tensors:
|
||||
* (kTfLiteUInt8/kTfLiteFloat32)
|
||||
* - list of segmented masks.
|
||||
* - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1.
|
||||
* - if `output_type` is CONFIDENCE_MASK, a float32 Image list of size
|
||||
* `channels`.
|
||||
* - batch is always 1
|
||||
*/
|
||||
export class InteractiveSegmenter extends VisionTaskRunner {
|
||||
private userCallback: SegmentationMaskCallback = () => {};
|
||||
private readonly options: ImageSegmenterGraphOptionsProto;
|
||||
private readonly segmenterOptions: SegmenterOptionsProto;
|
||||
|
||||
/**
|
||||
* Initializes the Wasm runtime and creates a new interactive segmenter from
|
||||
* the provided options.
|
||||
* @param wasmFileset A configuration object that provides the location of
|
||||
* the Wasm binary and its loader.
|
||||
* @param interactiveSegmenterOptions The options for the Interactive
|
||||
* Segmenter. Note that either a path to the model asset or a model buffer
|
||||
* needs to be provided (via `baseOptions`).
|
||||
* @return A new `InteractiveSegmenter`.
|
||||
*/
|
||||
static createFromOptions(
|
||||
wasmFileset: WasmFileset,
|
||||
interactiveSegmenterOptions: InteractiveSegmenterOptions):
|
||||
Promise<InteractiveSegmenter> {
|
||||
return VisionTaskRunner.createInstance(
|
||||
InteractiveSegmenter, /* initializeCanvas= */ true, wasmFileset,
|
||||
interactiveSegmenterOptions);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes the Wasm runtime and creates a new interactive segmenter based
|
||||
* on the provided model asset buffer.
|
||||
* @param wasmFileset A configuration object that provides the location of
|
||||
* the Wasm binary and its loader.
|
||||
* @param modelAssetBuffer A binary representation of the model.
|
||||
* @return A new `InteractiveSegmenter`.
|
||||
*/
|
||||
static createFromModelBuffer(
|
||||
wasmFileset: WasmFileset,
|
||||
modelAssetBuffer: Uint8Array): Promise<InteractiveSegmenter> {
|
||||
return VisionTaskRunner.createInstance(
|
||||
InteractiveSegmenter, /* initializeCanvas= */ true, wasmFileset,
|
||||
{baseOptions: {modelAssetBuffer}});
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes the Wasm runtime and creates a new interactive segmenter based
|
||||
* on the path to the model asset.
|
||||
* @param wasmFileset A configuration object that provides the location of
|
||||
* the Wasm binary and its loader.
|
||||
* @param modelAssetPath The path to the model asset.
|
||||
* @return A new `InteractiveSegmenter`.
|
||||
*/
|
||||
static createFromModelPath(
|
||||
wasmFileset: WasmFileset,
|
||||
modelAssetPath: string): Promise<InteractiveSegmenter> {
|
||||
return VisionTaskRunner.createInstance(
|
||||
InteractiveSegmenter, /* initializeCanvas= */ true, wasmFileset,
|
||||
{baseOptions: {modelAssetPath}});
|
||||
}
|
||||
|
||||
/** @hideconstructor */
|
||||
constructor(
|
||||
wasmModule: WasmModule,
|
||||
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
|
||||
super(
|
||||
new VisionGraphRunner(wasmModule, glCanvas), IMAGE_IN_STREAM,
|
||||
NORM_RECT_IN_STREAM, /* roiAllowed= */ false);
|
||||
this.options = new ImageSegmenterGraphOptionsProto();
|
||||
this.segmenterOptions = new SegmenterOptionsProto();
|
||||
this.options.setSegmenterOptions(this.segmenterOptions);
|
||||
this.options.setBaseOptions(new BaseOptionsProto());
|
||||
}
|
||||
|
||||
|
||||
protected override get baseOptions(): BaseOptionsProto {
|
||||
return this.options.getBaseOptions()!;
|
||||
}
|
||||
|
||||
protected override set baseOptions(proto: BaseOptionsProto) {
|
||||
this.options.setBaseOptions(proto);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets new options for the interactive segmenter.
|
||||
*
|
||||
* Calling `setOptions()` with a subset of options only affects those
|
||||
* options. You can reset an option back to its default value by
|
||||
* explicitly setting it to `undefined`.
|
||||
*
|
||||
* @param options The options for the interactive segmenter.
|
||||
* @return A Promise that resolves when the settings have been applied.
|
||||
*/
|
||||
override setOptions(options: InteractiveSegmenterOptions): Promise<void> {
|
||||
if (options.outputType === 'CONFIDENCE_MASK') {
|
||||
this.segmenterOptions.setOutputType(
|
||||
SegmenterOptionsProto.OutputType.CONFIDENCE_MASK);
|
||||
} else {
|
||||
this.segmenterOptions.setOutputType(
|
||||
SegmenterOptionsProto.OutputType.CATEGORY_MASK);
|
||||
}
|
||||
|
||||
return super.applyOptions(options);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs interactive segmentation on the provided single image and invokes
|
||||
* the callback with the response. The `roi` parameter is used to represent a
|
||||
* user's region of interest for segmentation.
|
||||
*
|
||||
* If the output_type is `CATEGORY_MASK`, the callback is invoked with a vector
|
||||
* of images that represent the per-category segmented image mask. If the
|
||||
* output_type is `CONFIDENCE_MASK`, the callback is invoked with a vector of
|
||||
* images that contains only one confidence image mask. The method returns
|
||||
* synchronously once the callback returns.
|
||||
*
|
||||
* @param image An image to process.
|
||||
* @param roi The region of interest for segmentation.
|
||||
* @param callback The callback that is invoked with the segmented masks. The
|
||||
* lifetime of the returned data is only guaranteed for the duration of the
|
||||
* callback.
|
||||
*/
|
||||
segment(
|
||||
image: ImageSource, roi: RegionOfInterest,
|
||||
callback: SegmentationMaskCallback): void;
|
||||
/**
|
||||
* Performs interactive segmentation on the provided single image and invokes
|
||||
* the callback with the response. The `roi` parameter is used to represent a
|
||||
* user's region of interest for segmentation.
|
||||
*
|
||||
* The 'image_processing_options' parameter can be used to specify the
|
||||
* rotation to apply to the image before performing segmentation, by setting
|
||||
* its 'rotationDegrees' field. Note that specifying a region-of-interest
|
||||
* using the 'regionOfInterest' field is NOT supported and will result in an
|
||||
* error.
|
||||
*
|
||||
* If the output_type is `CATEGORY_MASK`, the callback is invoked with a vector
|
||||
* of images that represent the per-category segmented image mask. If the
|
||||
* output_type is `CONFIDENCE_MASK`, the callback is invoked with a vector of
|
||||
* images that contains only one confidence image mask. The method returns
|
||||
* synchronously once the callback returns.
|
||||
*
|
||||
* @param image An image to process.
|
||||
* @param roi The region of interest for segmentation.
|
||||
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how
|
||||
* to process the input image before running inference.
|
||||
* @param callback The callback that is invoked with the segmented masks. The
|
||||
* lifetime of the returned data is only guaranteed for the duration of the
|
||||
* callback.
|
||||
*/
|
||||
segment(
|
||||
image: ImageSource, roi: RegionOfInterest,
|
||||
imageProcessingOptions: ImageProcessingOptions,
|
||||
callback: SegmentationMaskCallback): void;
|
||||
segment(
|
||||
image: ImageSource, roi: RegionOfInterest,
|
||||
imageProcessingOptionsOrCallback: ImageProcessingOptions|
|
||||
SegmentationMaskCallback,
|
||||
callback?: SegmentationMaskCallback): void {
|
||||
const imageProcessingOptions =
|
||||
typeof imageProcessingOptionsOrCallback !== 'function' ?
|
||||
imageProcessingOptionsOrCallback :
|
||||
{};
|
||||
|
||||
this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ?
|
||||
imageProcessingOptionsOrCallback :
|
||||
callback!;
|
||||
|
||||
this.processRenderData(roi, this.getSynctheticTimestamp());  // NOLINT: base class spells "Syncthetic"
|
||||
this.processImageData(image, imageProcessingOptions);
|
||||
this.userCallback = () => {};
|
||||
}
|
||||
|
||||
/** Updates the MediaPipe graph configuration. */
|
||||
protected override refreshGraph(): void {
|
||||
const graphConfig = new CalculatorGraphConfig();
|
||||
graphConfig.addInputStream(IMAGE_IN_STREAM);
|
||||
graphConfig.addInputStream(ROI_IN_STREAM);
|
||||
graphConfig.addInputStream(NORM_RECT_IN_STREAM);
|
||||
graphConfig.addOutputStream(IMAGE_OUT_STREAM);
|
||||
|
||||
const calculatorOptions = new CalculatorOptions();
|
||||
calculatorOptions.setExtension(
|
||||
ImageSegmenterGraphOptionsProto.ext, this.options);
|
||||
|
||||
const segmenterNode = new CalculatorGraphConfig.Node();
|
||||
segmenterNode.setCalculator(INTERACTIVE_SEGMENTER_GRAPH);
|
||||
segmenterNode.addInputStream('IMAGE:' + IMAGE_IN_STREAM);
|
||||
segmenterNode.addInputStream('ROI:' + ROI_IN_STREAM);
|
||||
segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_IN_STREAM);
|
||||
segmenterNode.addOutputStream('GROUPED_SEGMENTATION:' + IMAGE_OUT_STREAM);
|
||||
segmenterNode.setOptions(calculatorOptions);
|
||||
|
||||
graphConfig.addNode(segmenterNode);
|
||||
|
||||
this.graphRunner.attachImageVectorListener(
|
||||
IMAGE_OUT_STREAM, (masks, timestamp) => {
|
||||
if (masks.length === 0) {
|
||||
this.userCallback([], 0, 0);
|
||||
} else {
|
||||
this.userCallback(
|
||||
masks.map(m => m.data), masks[0].width, masks[0].height);
|
||||
}
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
});
|
||||
this.graphRunner.attachEmptyPacketListener(IMAGE_OUT_STREAM, timestamp => {
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
});
|
||||
|
||||
const binaryGraph = graphConfig.serializeBinary();
|
||||
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts the user-facing RegionOfInterest message to the RenderData proto
|
||||
* and sends it to the graph.
|
||||
*/
|
||||
private processRenderData(roi: RegionOfInterest, timestamp: number): void {
|
||||
const renderData = new RenderDataProto();
|
||||
|
||||
const renderAnnotation = new RenderAnnotationProto();
|
||||
|
||||
const color = new ColorProto();
|
||||
color.setR(255);
|
||||
renderAnnotation.setColor(color);
|
||||
|
||||
const point = new RenderAnnotationProto.Point();
|
||||
point.setNormalized(true);
|
||||
point.setX(roi.keypoint.x);
|
||||
point.setY(roi.keypoint.y);
|
||||
renderAnnotation.setPoint(point);
|
||||
|
||||
renderData.addRenderAnnotations(renderAnnotation);
|
||||
|
||||
this.graphRunner.addProtoToStream(
|
||||
renderData.serializeBinary(), 'mediapipe.RenderData', ROI_IN_STREAM,
|
||||
timestamp);
|
||||
}
|
||||
}
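
The following is a minimal usage sketch of the API defined in this file; it is not part of the commit. The model path, image element, and `WasmFileset` instance are placeholders a caller would supply.

// Usage sketch (not part of the diff): `wasmFileset` and `image` are assumed
// to be provided by the caller; the model path is a placeholder.
async function runInteractiveSegmentation(
    wasmFileset: WasmFileset, image: HTMLImageElement): Promise<void> {
  const segmenter = await InteractiveSegmenter.createFromModelPath(
      wasmFileset, '/assets/interactive_segmenter.tflite');
  await segmenter.setOptions({outputType: 'CATEGORY_MASK'});

  // Segment around a user click at the center of the image. The mask data is
  // only valid for the duration of the callback.
  segmenter.segment(
      image, {keypoint: {x: 0.5, y: 0.5}}, (masks, width, height) => {
        console.log(`Received ${masks.length} mask(s) of size ${width}x${height}`);
      });
}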
|
||||
|
||||
|
36
mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_options.d.ts
vendored
Normal file
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options';
|
||||
|
||||
/** Options to configure the MediaPipe Interactive Segmenter Task */
|
||||
export interface InteractiveSegmenterOptions extends TaskRunnerOptions {
|
||||
/**
|
||||
* The output type of segmentation results.
|
||||
*
|
||||
* The two supported modes are:
|
||||
* - Category Mask: Gives a single output mask where each pixel represents
|
||||
* the class which the pixel in the original image was
|
||||
* predicted to belong to.
|
||||
* - Confidence Mask: Gives a list of output masks (one for each class). For
|
||||
* each mask, the pixel represents the prediction
|
||||
* confidence, usually in the [0.0, 1.0] range.
|
||||
*
|
||||
* Defaults to `CATEGORY_MASK`.
|
||||
*/
|
||||
outputType?: 'CATEGORY_MASK'|'CONFIDENCE_MASK'|undefined;
|
||||
}
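
A short sketch of how these options interact with `setOptions()` as documented above; `segmenter` is an assumed, pre-existing `InteractiveSegmenter` instance, not something defined in this file.

// Option semantics sketch (assumes an existing `segmenter` instance).
await segmenter.setOptions({outputType: 'CONFIDENCE_MASK'});  // per-class float masks
await segmenter.setOptions({outputType: undefined});          // reset to the CATEGORY_MASK default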
|
|
@ -0,0 +1,214 @@
|
|||
/**
|
||||
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import 'jasmine';
|
||||
|
||||
// Placeholder for internal dependency on encodeByteArray
|
||||
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
||||
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
|
||||
import {RenderData as RenderDataProto} from '../../../../util/render_data_pb';
|
||||
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
|
||||
|
||||
import {InteractiveSegmenter, RegionOfInterest} from './interactive_segmenter';
|
||||
|
||||
|
||||
const ROI: RegionOfInterest = {
|
||||
keypoint: {x: 0.1, y: 0.2}
|
||||
};
|
||||
|
||||
class InteractiveSegmenterFake extends InteractiveSegmenter implements
|
||||
MediapipeTasksFake {
|
||||
calculatorName =
|
||||
'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph';
|
||||
attachListenerSpies: jasmine.Spy[] = [];
|
||||
graph: CalculatorGraphConfig|undefined;
|
||||
|
||||
fakeWasmModule: SpyWasmModule;
|
||||
imageVectorListener:
|
||||
((images: WasmImage[], timestamp: number) => void)|undefined;
|
||||
lastRoi?: RenderDataProto;
|
||||
|
||||
constructor() {
|
||||
super(createSpyWasmModule(), /* glCanvas= */ null);
|
||||
this.fakeWasmModule =
|
||||
this.graphRunner.wasmModule as unknown as SpyWasmModule;
|
||||
|
||||
this.attachListenerSpies[0] =
|
||||
spyOn(this.graphRunner, 'attachImageVectorListener')
|
||||
.and.callFake((stream, listener) => {
|
||||
expect(stream).toEqual('image_out');
|
||||
this.imageVectorListener = listener;
|
||||
});
|
||||
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
|
||||
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
|
||||
});
|
||||
spyOn(this.graphRunner, 'addGpuBufferAsImageToStream');
|
||||
|
||||
spyOn(this.graphRunner, 'addProtoToStream')
|
||||
.and.callFake((data, protoName, stream) => {
|
||||
if (stream === 'roi_in') {
|
||||
expect(protoName).toEqual('mediapipe.RenderData');
|
||||
this.lastRoi = RenderDataProto.deserializeBinary(data);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
describe('InteractiveSegmenter', () => {
|
||||
let interactiveSegmenter: InteractiveSegmenterFake;
|
||||
|
||||
beforeEach(async () => {
|
||||
addJasmineCustomFloatEqualityTester();
|
||||
interactiveSegmenter = new InteractiveSegmenterFake();
|
||||
await interactiveSegmenter.setOptions(
|
||||
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||
});
|
||||
|
||||
it('initializes graph', async () => {
|
||||
verifyGraph(interactiveSegmenter);
|
||||
verifyListenersRegistered(interactiveSegmenter);
|
||||
});
|
||||
|
||||
it('reloads graph when settings are changed', async () => {
|
||||
await interactiveSegmenter.setOptions({outputType: 'CATEGORY_MASK'});
|
||||
verifyGraph(interactiveSegmenter, [['segmenterOptions', 'outputType'], 1]);
|
||||
verifyListenersRegistered(interactiveSegmenter);
|
||||
|
||||
await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
|
||||
verifyGraph(interactiveSegmenter, [['segmenterOptions', 'outputType'], 2]);
|
||||
verifyListenersRegistered(interactiveSegmenter);
|
||||
});
|
||||
|
||||
it('can use custom models', async () => {
|
||||
const newModel = new Uint8Array([0, 1, 2, 3, 4]);
|
||||
const newModelBase64 = Buffer.from(newModel).toString('base64');
|
||||
await interactiveSegmenter.setOptions({
|
||||
baseOptions: {
|
||||
modelAssetBuffer: newModel,
|
||||
}
|
||||
});
|
||||
|
||||
verifyGraph(
|
||||
interactiveSegmenter,
|
||||
/* expectedCalculatorOptions= */ undefined,
|
||||
/* expectedBaseOptions= */
|
||||
[
|
||||
'modelAsset', {
|
||||
fileContent: newModelBase64,
|
||||
fileName: undefined,
|
||||
fileDescriptorMeta: undefined,
|
||||
filePointerMeta: undefined
|
||||
}
|
||||
]);
|
||||
});
|
||||
|
||||
|
||||
describe('setOptions()', () => {
|
||||
const fieldPath = ['segmenterOptions', 'outputType'];
|
||||
|
||||
it(`can set outputType`, async () => {
|
||||
await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
|
||||
verifyGraph(interactiveSegmenter, [fieldPath, 2]);
|
||||
});
|
||||
|
||||
it(`can clear outputType`, async () => {
|
||||
await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
|
||||
verifyGraph(interactiveSegmenter, [fieldPath, 2]);
|
||||
await interactiveSegmenter.setOptions({outputType: undefined});
|
||||
verifyGraph(interactiveSegmenter, [fieldPath, 1]);
|
||||
});
|
||||
});
|
||||
|
||||
it('doesn\'t support region of interest', () => {
|
||||
expect(() => {
|
||||
interactiveSegmenter.segment(
|
||||
{} as HTMLImageElement, ROI,
|
||||
{regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}, () => {});
|
||||
}).toThrowError('This task doesn\'t support region-of-interest.');
|
||||
});
|
||||
|
||||
it('sends region-of-interest', (done) => {
|
||||
interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||
expect(interactiveSegmenter.lastRoi).toBeDefined();
|
||||
expect(interactiveSegmenter.lastRoi!.toObject().renderAnnotationsList![0])
|
||||
.toEqual(jasmine.objectContaining({
|
||||
color: {r: 255, b: undefined, g: undefined},
|
||||
}));
|
||||
done();
|
||||
});
|
||||
|
||||
interactiveSegmenter.segment({} as HTMLImageElement, ROI, () => {});
|
||||
});
|
||||
|
||||
it('supports category masks', (done) => {
|
||||
const mask = new Uint8Array([1, 2, 3, 4]);
|
||||
|
||||
// Pass the test data to our listener
|
||||
interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||
verifyListenersRegistered(interactiveSegmenter);
|
||||
interactiveSegmenter.imageVectorListener!(
|
||||
[
|
||||
{data: mask, width: 2, height: 2},
|
||||
],
|
||||
/* timestamp= */ 1337);
|
||||
});
|
||||
|
||||
// Invoke the image segmenter
|
||||
interactiveSegmenter.segment(
|
||||
{} as HTMLImageElement, ROI, (masks, width, height) => {
|
||||
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
|
||||
.toHaveBeenCalled();
|
||||
expect(masks).toHaveSize(1);
|
||||
expect(masks[0]).toEqual(mask);
|
||||
expect(width).toEqual(2);
|
||||
expect(height).toEqual(2);
|
||||
done();
|
||||
});
|
||||
});
|
||||
|
||||
it('supports confidence masks', async () => {
|
||||
const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]);
|
||||
const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]);
|
||||
|
||||
await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
|
||||
|
||||
// Pass the test data to our listener
|
||||
interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||
verifyListenersRegistered(interactiveSegmenter);
|
||||
interactiveSegmenter.imageVectorListener!(
|
||||
[
|
||||
{data: mask1, width: 2, height: 2},
|
||||
{data: mask2, width: 2, height: 2},
|
||||
],
|
||||
1337);
|
||||
});
|
||||
|
||||
return new Promise<void>(resolve => {
|
||||
// Invoke the image segmenter
|
||||
interactiveSegmenter.segment(
|
||||
{} as HTMLImageElement, ROI, (masks, width, height) => {
|
||||
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
|
||||
.toHaveBeenCalled();
|
||||
expect(masks).toHaveSize(2);
|
||||
expect(masks[0]).toEqual(mask1);
|
||||
expect(masks[1]).toEqual(mask2);
|
||||
expect(width).toEqual(2);
|
||||
expect(height).toEqual(2);
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
|
@ -20,4 +20,5 @@ export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker';
|
|||
export * from '../../../tasks/web/vision/image_classifier/image_classifier';
|
||||
export * from '../../../tasks/web/vision/image_embedder/image_embedder';
|
||||
export * from '../../../tasks/web/vision/image_segmenter/image_segmenter';
|
||||
export * from '../../../tasks/web/vision/interactive_segmenter/interactive_segmenter';
|
||||
export * from '../../../tasks/web/vision/object_detector/object_detector';
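
With the re-export added above, the new task is reachable from the top-level vision entry point. A hedged import sketch, assuming the bundle is consumed under its published npm package name:

import {InteractiveSegmenter, RegionOfInterest} from '@mediapipe/tasks-vision';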
|
||||
|
|
6
third_party/BUILD
vendored
|
@ -169,7 +169,11 @@ cmake_external(
|
|||
"-lm",
|
||||
"-lpthread",
|
||||
"-lrt",
|
||||
],
|
||||
] + select({
|
||||
"//mediapipe:ios": ["-framework Cocoa"],
|
||||
"//mediapipe:macos": ["-framework Cocoa"],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
shared_libraries = select({
|
||||
"@bazel_tools//src/conditions:darwin": ["libopencv_%s.%s.dylib" % (module, OPENCV_SO_VERSION) for module in OPENCV_MODULES],
|
||||
# Only the shared objects listed here will be linked in the directory
|
||||
|
|
86
third_party/external_files.bzl
vendored
|
@ -72,8 +72,8 @@ def external_files():
|
|||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_BUILD_orig",
|
||||
sha256 = "64d5343a6a5f9be06db0a5074a2260f9ae63a989fe01702832cd215680dc19c1",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD.orig?generation=1678323576393653"],
|
||||
sha256 = "d86b98b82e00dd87cd46bd1429bf5eaa007b500c1a24d9316b73309f2e6c8df8",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD.orig?generation=1678737479599640"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
|
@ -208,10 +208,34 @@ def external_files():
|
|||
urls = ["https://storage.googleapis.com/mediapipe-assets/corrupted_mobilenet_v1_0.25_224_1_default_1.tflite?generation=1661875706780536"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_deeplabv3_json",
|
||||
sha256 = "f299835bd9ea1cceb25fdf40a761a22716cbd20025cd67c365a860527f178b7f",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/deeplabv3.json?generation=1678818040715103"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_deeplabv3_tflite",
|
||||
sha256 = "9711334db2b01d5894feb8ed0f5cb3e97d125b8d229f8d8692f625801818f5ef",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/deeplabv3.tflite?generation=1661875711618421"],
|
||||
sha256 = "5faed2c653905d3e22a8f6f29ee198da84e9b0e7936a207bf431f17f6b4d87ff",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/deeplabv3.tflite?generation=1678775085237701"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_deeplabv3_with_activation_json",
|
||||
sha256 = "a7633476d02f970db3cc30f5f027bcb608149e02207b2ccae36a4b69d730c82c",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/deeplabv3_with_activation.json?generation=1678818047050984"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_deeplabv3_without_labels_json",
|
||||
sha256 = "7d045a583a4046f17a52d2078b0175607a45ed0cc187558325f9c66534c08401",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/deeplabv3_without_labels.json?generation=1678818050191996"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_deeplabv3_without_metadata_tflite",
|
||||
sha256 = "68a539782c2c6a72f8aac3724600124a85ed977162b44e84cbae5db717c933c6",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/deeplabv3_without_metadata.tflite?generation=1678818053623010"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
|
@ -390,8 +414,8 @@ def external_files():
|
|||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_hair_segmentation_tflite",
|
||||
sha256 = "0bec40bc9ba97c4143f3d4225a935014abffea37c1f3766ae32aba3f2748e711",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/hair_segmentation.tflite?generation=1678218355806671"],
|
||||
sha256 = "7cbddcfe6f6e10c3e0a509eb2e14225fda5c0de6c35e2e8c6ca8e3971988fc17",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/hair_segmentation.tflite?generation=1678775089064550"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
|
@ -823,7 +847,7 @@ def external_files():
|
|||
http_file(
|
||||
name = "com_google_mediapipe_portrait_expected_face_geometry_with_attention_pbtxt",
|
||||
sha256 = "7ed1eed98e61e0a10811bb611c895d87c8023f398a36db01b6d9ba2e1ab09e16",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_face_geometry_with_attention.pbtxt?generation=1678505004840652"],
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_face_geometry_with_attention.pbtxt?generation=1678737486927530"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
|
@ -864,8 +888,20 @@ def external_files():
|
|||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_pose_detection_tflite",
|
||||
sha256 = "a63c614bef30d35947f13be361820b1e4e3bec9cfeebf4d11216a18373108e85",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/pose_detection.tflite?generation=1661875889147923"],
|
||||
sha256 = "9ba9dd3d42efaaba86b4ff0122b06f29c4122e756b329d89dca1e297fd8f866c",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/pose_detection.tflite?generation=1678737489600422"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_pose_expected_detection_pbtxt",
|
||||
sha256 = "e0d40e98dd5320a780a642c336d0c8720243ac5bcc0e39c4061ad970a503ae24",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/pose_expected_detection.pbtxt?generation=1678737492211540"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_pose_jpg",
|
||||
sha256 = "c8a830ed683c0276d713dd5aeda28f415f10cd6291972084a40d0d8b934ed62b",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/pose.jpg?generation=1678737494661975"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
|
@ -964,6 +1000,18 @@ def external_files():
|
|||
urls = ["https://storage.googleapis.com/mediapipe-assets/segmentation_input_rotation0.jpg?generation=1661875914048401"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_segmentation_mask_meta_json",
|
||||
sha256 = "4294d53b309c1fbe38a5184de4057576c3dec14e07d16491f1dd459ac9116ab3",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/segmentation_mask_meta.json?generation=1678818065134737"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_segmenter_labelmap_txt",
|
||||
sha256 = "d9efa78274f1799ddbcab1f87263e19dae338c1697de47a5b270c9526c45d364",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/segmenter_labelmap.txt?generation=1678818068181025"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_selfie_segm_128_128_3_expected_mask_jpg",
|
||||
sha256 = "a295f3ab394a5e0caff2db5041337da58341ec331f1413ef91f56e0d650b4a1e",
|
||||
|
@ -972,8 +1020,8 @@ def external_files():
|
|||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_selfie_segm_128_128_3_tflite",
|
||||
sha256 = "bb154f248543c0738e32f1c74375245651351a84746dc21f10bdfaabd8fae4ca",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segm_128_128_3.tflite?generation=1661875919964123"],
|
||||
sha256 = "8322982866488b063af6531b1d16ac27c7bf404135b7905f20aaf5e6af7aa45b",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segm_128_128_3.tflite?generation=1678775097370282"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
|
@ -984,20 +1032,20 @@ def external_files():
|
|||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_selfie_segm_144_256_3_tflite",
|
||||
sha256 = "5c770b8834ad50586599eae7710921be09d356898413fc0bf37a9458da0610eb",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segm_144_256_3.tflite?generation=1661875925519713"],
|
||||
sha256 = "f16a9551a408edeadd53f70d1d2911fc20f9f9de7a394129a268ca9faa2d6a08",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segm_144_256_3.tflite?generation=1678775099616375"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_selfie_segmentation_landscape_tflite",
|
||||
sha256 = "4aafe6223bb8dac6fac8ca8ed56852870a33051ef3f6238822d282a109962894",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation_landscape.tflite?generation=1661875928328455"],
|
||||
sha256 = "28fb4c287d6295a2dba6c1f43b43315a37f927ddcd6693d635d625d176eef162",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation_landscape.tflite?generation=1678775102234495"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_selfie_segmentation_tflite",
|
||||
sha256 = "8d13b7fae74af625c641226813616a2117bd6bca19eb3b75574621fc08557f27",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation.tflite?generation=1661875931201364"],
|
||||
sha256 = "b0e2ec6f95107795b952b27f3d92806b45f0bc069dac76dcd264cd1b90d61c6c",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation.tflite?generation=1678775104900954"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
|
@ -1224,8 +1272,8 @@ def external_files():
|
|||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_object_detection_saved_model_README_md",
|
||||
sha256 = "fe163cf12fbd017738a2fd360c03d223e964ba6404ac75c635f5918784e9c34d",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/README.md?generation=1661875995856372"],
|
||||
sha256 = "acc23dee09f69210717ac060035c844ba902e8271486f1086f29fb156c236690",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/README.md?generation=1678737498915254"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
|
|