Merge branch 'google:master' into face-landmarker-python

This commit is contained in:
Kinar R 2023-03-15 11:03:15 +05:30 committed by GitHub
commit 647db21fc3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
79 changed files with 3290 additions and 234 deletions

View File

@@ -499,8 +499,8 @@ cc_crosstool(name = "crosstool")
# Node dependencies
http_archive(
    name = "build_bazel_rules_nodejs",
-   sha256 = "5aae76dced38f784b58d9776e4ab12278bc156a9ed2b1d9fcd3e39921dc88fda",
-   urls = ["https://github.com/bazelbuild/rules_nodejs/releases/download/5.7.1/rules_nodejs-5.7.1.tar.gz"],
+   sha256 = "94070eff79305be05b7699207fbac5d2608054dd53e6109f7d00d923919ff45a",
+   urls = ["https://github.com/bazelbuild/rules_nodejs/releases/download/5.8.2/rules_nodejs-5.8.2.tar.gz"],
)
load("@build_bazel_rules_nodejs//:repositories.bzl", "build_bazel_rules_nodejs_dependencies")

View File

@@ -1270,6 +1270,50 @@ cc_library(
    alwayslink = 1,
)
mediapipe_proto_library(
name = "flat_color_image_calculator_proto",
srcs = ["flat_color_image_calculator.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/util:color_proto",
],
)
cc_library(
name = "flat_color_image_calculator",
srcs = ["flat_color_image_calculator.cc"],
deps = [
":flat_color_image_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:opencv_core",
"//mediapipe/util:color_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)
cc_test(
name = "flat_color_image_calculator_test",
srcs = ["flat_color_image_calculator_test.cc"],
deps = [
":flat_color_image_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework:packet",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:gtest",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/util:color_cc_proto",
],
)
cc_library(
    name = "from_image_calculator",
    srcs = ["from_image_calculator.cc"],

View File

@@ -0,0 +1,138 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/calculators/util/flat_color_image_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/util/color.pb.h"
namespace mediapipe {
namespace {
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Node;
using ::mediapipe::api2::Output;
} // namespace
// A calculator for generating an image filled with a single color.
//
// Inputs:
// IMAGE (Image, optional)
// If provided, the output will have the same size as the input image.
// COLOR (Color proto, optional)
// Color to paint the output with. Takes precedence over the equivalent
// calculator options.
//
// Outputs:
// IMAGE (Image)
// Image filled with the requested color.
//
// Example usage:
// node {
// calculator: "FlatColorImageCalculator"
// input_stream: "IMAGE:image"
// input_stream: "COLOR:color"
// output_stream: "IMAGE:blank_image"
// options {
// [mediapipe.FlatColorImageCalculatorOptions.ext] {
// color: {
// r: 255
// g: 255
// b: 255
// }
// }
// }
// }
class FlatColorImageCalculator : public Node {
public:
static constexpr Input<Image>::Optional kInImage{"IMAGE"};
static constexpr Input<Color>::Optional kInColor{"COLOR"};
static constexpr Output<Image> kOutImage{"IMAGE"};
MEDIAPIPE_NODE_CONTRACT(kInImage, kInColor, kOutImage);
static absl::Status UpdateContract(CalculatorContract* cc) {
const auto& options = cc->Options<FlatColorImageCalculatorOptions>();
RET_CHECK(kInImage(cc).IsConnected() ^
(options.has_output_height() || options.has_output_width()))
<< "Either set IMAGE input stream, or set through options";
RET_CHECK(kInColor(cc).IsConnected() ^ options.has_color())
<< "Either set COLOR input stream, or set through options";
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
private:
bool use_dimension_from_option_ = false;
bool use_color_from_option_ = false;
};
MEDIAPIPE_REGISTER_NODE(FlatColorImageCalculator);
absl::Status FlatColorImageCalculator::Open(CalculatorContext* cc) {
use_dimension_from_option_ = !kInImage(cc).IsConnected();
use_color_from_option_ = !kInColor(cc).IsConnected();
return absl::OkStatus();
}
absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) {
const auto& options = cc->Options<FlatColorImageCalculatorOptions>();
int output_height = -1;
int output_width = -1;
if (use_dimension_from_option_) {
output_height = options.output_height();
output_width = options.output_width();
} else if (!kInImage(cc).IsEmpty()) {
const Image& input_image = kInImage(cc).Get();
output_height = input_image.height();
output_width = input_image.width();
} else {
return absl::OkStatus();
}
Color color;
if (use_color_from_option_) {
color = options.color();
} else if (!kInColor(cc).IsEmpty()) {
color = kInColor(cc).Get();
} else {
return absl::OkStatus();
}
auto output_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
output_width, output_height);
cv::Mat output_mat = mediapipe::formats::MatView(output_frame.get());
output_mat.setTo(cv::Scalar(color.r(), color.g(), color.b()));
kOutImage(cc).Send(Image(output_frame));
return absl::OkStatus();
}
} // namespace mediapipe
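The in-code example above fills the color through the options and takes the dimensions from the IMAGE stream. The complementary configuration, with dimensions fixed via options and the color arriving on the COLOR stream (the combination exercised by the SpecifyDimensionThroughOptions test below), would look roughly like the following sketch; the stream name and the 256x256 size are illustrative only:

node {
  calculator: "FlatColorImageCalculator"
  input_stream: "COLOR:color"
  output_stream: "IMAGE:blank_image"
  options {
    [mediapipe.FlatColorImageCalculatorOptions.ext] {
      output_width: 256
      output_height: 256
    }
  }
}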

View File

@@ -0,0 +1,32 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
import "mediapipe/util/color.proto";
message FlatColorImageCalculatorOptions {
extend CalculatorOptions {
optional FlatColorImageCalculatorOptions ext = 515548435;
}
// Output dimensions.
optional int32 output_width = 1;
optional int32 output_height = 2;
// The color to fill with in the output image.
optional Color color = 3;
}

View File

@@ -0,0 +1,210 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/util/color.pb.h"
namespace mediapipe {
namespace {
using ::testing::HasSubstr;
constexpr char kImageTag[] = "IMAGE";
constexpr char kColorTag[] = "COLOR";
constexpr int kImageWidth = 256;
constexpr int kImageHeight = 256;
TEST(FlatColorImageCalculatorTest, SpecifyColorThroughOptions) {
CalculatorRunner runner(R"pb(
calculator: "FlatColorImageCalculator"
input_stream: "IMAGE:image"
output_stream: "IMAGE:out_image"
options {
[mediapipe.FlatColorImageCalculatorOptions.ext] {
color: {
r: 100,
g: 200,
b: 255,
}
}
}
)pb");
auto image_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
kImageWidth, kImageHeight);
for (int ts = 0; ts < 3; ++ts) {
runner.MutableInputs()->Tag(kImageTag).packets.push_back(
MakePacket<Image>(image_frame).At(Timestamp(ts)));
}
MP_ASSERT_OK(runner.Run());
const auto& outputs = runner.Outputs().Tag(kImageTag).packets;
ASSERT_EQ(outputs.size(), 3);
for (const auto& packet : outputs) {
const auto& image = packet.Get<Image>();
EXPECT_EQ(image.width(), kImageWidth);
EXPECT_EQ(image.height(), kImageHeight);
auto image_frame = image.GetImageFrameSharedPtr();
auto* pixel_data = image_frame->PixelData();
EXPECT_EQ(pixel_data[0], 100);
EXPECT_EQ(pixel_data[1], 200);
EXPECT_EQ(pixel_data[2], 255);
}
}
TEST(FlatColorImageCalculatorTest, SpecifyDimensionThroughOptions) {
CalculatorRunner runner(R"pb(
calculator: "FlatColorImageCalculator"
input_stream: "COLOR:color"
output_stream: "IMAGE:out_image"
options {
[mediapipe.FlatColorImageCalculatorOptions.ext] {
output_width: 7,
output_height: 13,
}
}
)pb");
Color color;
color.set_r(0);
color.set_g(5);
color.set_b(0);
for (int ts = 0; ts < 3; ++ts) {
runner.MutableInputs()->Tag(kColorTag).packets.push_back(
MakePacket<Color>(color).At(Timestamp(ts)));
}
MP_ASSERT_OK(runner.Run());
const auto& outputs = runner.Outputs().Tag(kImageTag).packets;
ASSERT_EQ(outputs.size(), 3);
for (const auto& packet : outputs) {
const auto& image = packet.Get<Image>();
EXPECT_EQ(image.width(), 7);
EXPECT_EQ(image.height(), 13);
auto image_frame = image.GetImageFrameSharedPtr();
const uint8_t* pixel_data = image_frame->PixelData();
EXPECT_EQ(pixel_data[0], 0);
EXPECT_EQ(pixel_data[1], 5);
EXPECT_EQ(pixel_data[2], 0);
}
}
TEST(FlatColorImageCalculatorTest, FailureMissingDimension) {
CalculatorRunner runner(R"pb(
calculator: "FlatColorImageCalculator"
input_stream: "COLOR:color"
output_stream: "IMAGE:out_image"
)pb");
Color color;
color.set_r(0);
color.set_g(5);
color.set_b(0);
for (int ts = 0; ts < 3; ++ts) {
runner.MutableInputs()->Tag(kColorTag).packets.push_back(
MakePacket<Color>(color).At(Timestamp(ts)));
}
ASSERT_THAT(runner.Run().message(),
HasSubstr("Either set IMAGE input stream"));
}
TEST(FlatColorImageCalculatorTest, FailureMissingColor) {
CalculatorRunner runner(R"pb(
calculator: "FlatColorImageCalculator"
input_stream: "IMAGE:image"
output_stream: "IMAGE:out_image"
)pb");
auto image_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
kImageWidth, kImageHeight);
for (int ts = 0; ts < 3; ++ts) {
runner.MutableInputs()->Tag(kImageTag).packets.push_back(
MakePacket<Image>(image_frame).At(Timestamp(ts)));
}
ASSERT_THAT(runner.Run().message(),
HasSubstr("Either set COLOR input stream"));
}
TEST(FlatColorImageCalculatorTest, FailureDuplicateDimension) {
CalculatorRunner runner(R"pb(
calculator: "FlatColorImageCalculator"
input_stream: "IMAGE:image"
input_stream: "COLOR:color"
output_stream: "IMAGE:out_image"
options {
[mediapipe.FlatColorImageCalculatorOptions.ext] {
output_width: 7,
output_height: 13,
}
}
)pb");
auto image_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
kImageWidth, kImageHeight);
for (int ts = 0; ts < 3; ++ts) {
runner.MutableInputs()->Tag(kImageTag).packets.push_back(
MakePacket<Image>(image_frame).At(Timestamp(ts)));
}
ASSERT_THAT(runner.Run().message(),
HasSubstr("Either set IMAGE input stream"));
}
TEST(FlatColorImageCalculatorTest, FailureDuplicateColor) {
CalculatorRunner runner(R"pb(
calculator: "FlatColorImageCalculator"
input_stream: "IMAGE:image"
input_stream: "COLOR:color"
output_stream: "IMAGE:out_image"
options {
[mediapipe.FlatColorImageCalculatorOptions.ext] {
color: {
r: 100,
g: 200,
b: 255,
}
}
}
)pb");
Color color;
color.set_r(0);
color.set_g(5);
color.set_b(0);
for (int ts = 0; ts < 3; ++ts) {
runner.MutableInputs()->Tag(kColorTag).packets.push_back(
MakePacket<Color>(color).At(Timestamp(ts)));
}
ASSERT_THAT(runner.Run().message(),
HasSubstr("Either set COLOR input stream"));
}
} // namespace
} // namespace mediapipe

View File

@@ -204,7 +204,7 @@ def rewrite_mediapipe_proto(name, rewrite_proto, source_proto, **kwargs):
        'import public "' + join_path + '";',
    )
    rewrite_ref = SubsituteCommand(
-       r"mediapipe\\.(" + rewrite_message_regex + ")",
+       r"mediapipe\.(" + rewrite_message_regex + ")",
        r"mediapipe.\\1",
    )
    rewrite_objc = SubsituteCommand(

View File

@@ -467,6 +467,7 @@ cc_library(
        "//mediapipe/framework/formats:frame_buffer",
        "//mediapipe/framework/formats:image_frame",
        "//mediapipe/framework/formats:yuv_image",
        "//mediapipe/util/frame_buffer:frame_buffer_util",
        "//third_party/libyuv",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",

View File

@@ -0,0 +1,48 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict test compatibility macro.
# Placeholder for internal Python strict library and test compatibility macro.
licenses(["notice"])
package(default_visibility = ["//mediapipe:__subpackages__"])
filegroup(
name = "testdata",
srcs = glob([
"testdata/**",
]),
)
py_library(
name = "dataset",
srcs = ["dataset.py"],
deps = [
"//mediapipe/model_maker/python/core/data:classification_dataset",
"//mediapipe/model_maker/python/vision/core:image_utils",
],
)
py_test(
name = "dataset_test",
srcs = ["dataset_test.py"],
data = [
":testdata",
],
deps = [
":dataset",
"//mediapipe/tasks/python/test:test_utils",
],
)

View File

@@ -0,0 +1,14 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe Model Maker Python Public API For Face Stylization."""

View File

@@ -0,0 +1,98 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Face stylizer dataset library."""
import logging
import os
import tensorflow as tf
from mediapipe.model_maker.python.core.data import classification_dataset
from mediapipe.model_maker.python.vision.core import image_utils
# TODO: Change to an unlabeled dataset if it makes sense.
class Dataset(classification_dataset.ClassificationDataset):
"""Dataset library for face stylizer fine tuning."""
@classmethod
def from_folder(
cls, dirname: str
) -> classification_dataset.ClassificationDataset:
"""Loads images from the given directory.
The style image dataset directory is expected to contain one subdirectory
whose name represents the label of the style. There can be one or multiple
images of the same style in that subdirectory. Supported input image formats
include 'jpg', 'jpeg', 'png'.
Args:
dirname: Name of the directory containing the image files.
Returns:
Dataset containing images and labels and other related info.
Raises:
ValueError: if the input data directory is empty.
"""
data_root = os.path.abspath(dirname)
# Assumes the image data of the same label are in the same subdirectory,
# gets image path and label names.
all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*'))
all_image_size = len(all_image_paths)
if all_image_size == 0:
raise ValueError('Invalid input data directory')
if not any(
fname.endswith(('.jpg', '.jpeg', '.png')) for fname in all_image_paths
):
raise ValueError('No images found under given directory')
label_names = sorted(
name
for name in os.listdir(data_root)
if os.path.isdir(os.path.join(data_root, name))
)
all_label_size = len(label_names)
index_by_label = dict(
(name, index) for index, name in enumerate(label_names)
)
# Get the style label from the subdirectory name.
all_image_labels = [
index_by_label[os.path.basename(os.path.dirname(path))]
for path in all_image_paths
]
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
image_ds = path_ds.map(
image_utils.load_image, num_parallel_calls=tf.data.AUTOTUNE
)
# Load label
label_ds = tf.data.Dataset.from_tensor_slices(
tf.cast(all_image_labels, tf.int64)
)
# Create a dataset of (image, label) pairs
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
logging.info(
'Load images dataset with size: %d, num_label: %d, labels: %s.',
all_image_size,
all_label_size,
', '.join(label_names),
)
return Dataset(
dataset=image_label_ds, size=all_image_size, label_names=label_names
)

View File

@@ -0,0 +1,48 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from mediapipe.model_maker.python.vision.face_stylizer import dataset
from mediapipe.tasks.python.test import test_utils
class DatasetTest(tf.test.TestCase):
def setUp(self):
super().setUp()
# TODO: Replace the stylize image dataset with licensed images.
self._test_data_dirname = 'testdata'
def test_from_folder(self):
input_data_dir = test_utils.get_test_data_path(self._test_data_dirname)
data = dataset.Dataset.from_folder(dirname=input_data_dir)
self.assertEqual(data.num_classes, 2)
self.assertEqual(data.label_names, ['cartoon', 'sketch'])
self.assertLen(data, 2)
def test_from_folder_raise_value_error_for_invalid_path(self):
with self.assertRaisesRegex(ValueError, 'Invalid input data directory'):
dataset.Dataset.from_folder(dirname='invalid')
def test_from_folder_raise_value_error_for_valid_no_data_path(self):
input_data_dir = test_utils.get_test_data_path('face_stylizer')
with self.assertRaisesRegex(
ValueError, 'No images found under given directory'
):
dataset.Dataset.from_folder(dirname=input_data_dir)
if __name__ == '__main__':
tf.test.main()

Binary file not shown (new image, 347 KiB).

Binary file not shown (new image, 336 KiB).
View File

@@ -57,6 +57,7 @@ pybind_extension(
        "//mediapipe/framework/formats:landmark_registration",
        "//mediapipe/framework/formats:rect_registration",
        "//mediapipe/modules/objectron/calculators:annotation_registration",
        "//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_registration",
    ],
)

View File

@@ -30,7 +30,7 @@ constexpr absl::string_view kMediaPipeTasksPayload = "MediaPipeTasksStatus";
//
// At runtime, such codes are meant to be attached (where applicable) to a
// `absl::Status` in a key-value manner with `kMediaPipeTasksPayload` as key and
-// stringifed error code as value (aka payload). This logic is encapsulated in
+// stringified error code as value (aka payload). This logic is encapsulated in
// the `CreateStatusWithPayload` helper below for convenience.
//
// The returned status includes:

View File

@@ -51,12 +51,11 @@ ModelAssetBundleResources::Create(
  auto model_bundle_resources = absl::WrapUnique(
      new ModelAssetBundleResources(tag, std::move(model_asset_bundle_file)));
  MP_RETURN_IF_ERROR(
-     model_bundle_resources->ExtractModelFilesFromExternalFileProto());
+     model_bundle_resources->ExtractFilesFromExternalFileProto());
  return model_bundle_resources;
}

-absl::Status
-ModelAssetBundleResources::ExtractModelFilesFromExternalFileProto() {
+absl::Status ModelAssetBundleResources::ExtractFilesFromExternalFileProto() {
  if (model_asset_bundle_file_->has_file_name()) {
    // If the model asset bundle file name is a relative path, searches the file
    // in a platform-specific location and returns the absolute path on success.
@@ -72,34 +71,32 @@ ModelAssetBundleResources::ExtractModelFilesFromExternalFileProto() {
      model_asset_bundle_file_handler_->GetFileContent().data();
  size_t buffer_size =
      model_asset_bundle_file_handler_->GetFileContent().size();
- return metadata::ExtractFilesfromZipFile(buffer_data, buffer_size,
-                                          &model_files_);
+ return metadata::ExtractFilesfromZipFile(buffer_data, buffer_size, &files_);
}

-absl::StatusOr<absl::string_view> ModelAssetBundleResources::GetModelFile(
+absl::StatusOr<absl::string_view> ModelAssetBundleResources::GetFile(
    const std::string& filename) const {
- auto it = model_files_.find(filename);
- if (it == model_files_.end()) {
-   auto model_files = ListModelFiles();
-   std::string all_model_files =
-       absl::StrJoin(model_files.begin(), model_files.end(), ", ");
+ auto it = files_.find(filename);
+ if (it == files_.end()) {
+   auto files = ListFiles();
+   std::string all_files = absl::StrJoin(files.begin(), files.end(), ", ");
    return CreateStatusWithPayload(
        StatusCode::kNotFound,
-       absl::StrFormat("No model file with name: %s. All model files in the "
-                       "model asset bundle are: %s.",
-                       filename, all_model_files),
+       absl::StrFormat("No file with name: %s. All files in the model asset "
+                       "bundle are: %s.",
+                       filename, all_files),
        MediaPipeTasksStatus::kFileNotFoundError);
  }
  return it->second;
}

-std::vector<std::string> ModelAssetBundleResources::ListModelFiles() const {
- std::vector<std::string> model_names;
- for (const auto& [model_name, _] : model_files_) {
-   model_names.push_back(model_name);
+std::vector<std::string> ModelAssetBundleResources::ListFiles() const {
+ std::vector<std::string> file_names;
+ for (const auto& [file_name, _] : files_) {
+   file_names.push_back(file_name);
  }
- return model_names;
+ return file_names;
}

}  // namespace core

View File

@@ -28,8 +28,8 @@ namespace core {
// The mediapipe task model asset bundle resources class.
// A ModelAssetBundleResources object, created from an external file proto,
// contains model asset bundle related resources and the method to extract the
-// tflite models or model asset bundles for the mediapipe sub-tasks. As the
-// resources are owned by the ModelAssetBundleResources object
+// tflite models, resource files or model asset bundles for the mediapipe
+// sub-tasks. As the resources are owned by the ModelAssetBundleResources object
// callers must keep ModelAssetBundleResources alive while using any of the
// resources.
class ModelAssetBundleResources {
@@ -50,14 +50,13 @@ class ModelAssetBundleResources {
  // Returns the model asset bundle resources tag.
  std::string GetTag() const { return tag_; }

- // Gets the contents of the model file (either tflite model file or model
- // bundle file) with the provided name. An error is returned if there is no
- // such model file.
- absl::StatusOr<absl::string_view> GetModelFile(
-     const std::string& filename) const;
+ // Gets the contents of the model file (either tflite model file, resource
+ // file or model bundle file) with the provided name. An error is returned if
+ // there is no such model file.
+ absl::StatusOr<absl::string_view> GetFile(const std::string& filename) const;

- // Lists all the model file names in the model asset model.
- std::vector<std::string> ListModelFiles() const;
+ // Lists all the file names in the model asset bundle.
+ std::vector<std::string> ListFiles() const;

 private:
  // Constructor.
@@ -65,9 +64,9 @@ class ModelAssetBundleResources {
      const std::string& tag,
      std::unique_ptr<proto::ExternalFile> model_asset_bundle_file);

- // Extracts the model files (either tflite model file or model bundle file)
- // from the external file proto.
- absl::Status ExtractModelFilesFromExternalFileProto();
+ // Extracts the model files (either tflite model file, resource file or model
+ // bundle file) from the external file proto.
+ absl::Status ExtractFilesFromExternalFileProto();

  // The model asset bundle resources tag.
  const std::string tag_;
@@ -78,11 +77,11 @@ class ModelAssetBundleResources {
  // The ExternalFileHandler for the model asset bundle.
  std::unique_ptr<ExternalFileHandler> model_asset_bundle_file_handler_;

- // The model files bundled in model asset bundle, as a map with the filename
+ // The files bundled in model asset bundle, as a map with the filename
  // (corresponding to a basename, e.g. "hand_detector.tflite") as key and
- // a pointer to the file contents as value. Each model file can be either
- // a TFLite model file or a model bundle file for sub-task.
- absl::flat_hash_map<std::string, absl::string_view> model_files_;
+ // a pointer to the file contents as value. Each file can be either a TFLite
+ // model file, resource file or a model bundle file for sub-task.
+ absl::flat_hash_map<std::string, absl::string_view> files_;
};

}  // namespace core
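To illustrate the renamed API, a minimal sketch (inside a function returning absl::Status; the bundle and file names are hypothetical):

  auto file_proto = std::make_unique<mediapipe::tasks::core::proto::ExternalFile>();
  file_proto->set_file_name("face_landmarker.task");  // hypothetical bundle path
  ASSIGN_OR_RETURN(auto resources,
                   ModelAssetBundleResources::Create("tag", std::move(file_proto)));
  for (const std::string& name : resources->ListFiles()) {
    // e.g. "face_detector.tflite", "face_landmarks_detector.tflite" (hypothetical)
  }
  ASSIGN_OR_RETURN(absl::string_view contents,
                   resources->GetFile("face_detector.tflite"));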

View File

@@ -66,10 +66,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromBinaryContent) {
      ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
                                        std::move(model_file)));
  MP_EXPECT_OK(
-     model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
-         .status());
+     model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
  MP_EXPECT_OK(
-     model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
+     model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
          .status());
}
@@ -81,10 +80,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFile) {
      ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
                                        std::move(model_file)));
  MP_EXPECT_OK(
-     model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
-         .status());
+     model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
  MP_EXPECT_OK(
-     model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
+     model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
          .status());
}
@@ -98,10 +96,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFileDescriptor) {
      ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
                                        std::move(model_file)));
  MP_EXPECT_OK(
-     model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
-         .status());
+     model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
  MP_EXPECT_OK(
-     model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
+     model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
          .status());
}
#endif  // _WIN32
@@ -115,10 +112,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFilePointer) {
      ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
                                        std::move(model_file)));
  MP_EXPECT_OK(
-     model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
-         .status());
+     model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
  MP_EXPECT_OK(
-     model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
+     model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
          .status());
}
@@ -147,7 +143,7 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidModelBundleFile) {
      ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
                                        std::move(model_file)));
  auto status_or_model_bundle_file =
-     model_bundle_resources->GetModelFile("dummy_hand_landmarker.task");
+     model_bundle_resources->GetFile("dummy_hand_landmarker.task");
  MP_EXPECT_OK(status_or_model_bundle_file.status());

  // Creates sub-task model asset bundle resources.
@@ -159,10 +155,10 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidModelBundleFile) {
      ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
                                        std::move(hand_landmaker_model_file)));
  MP_EXPECT_OK(hand_landmaker_model_bundle_resources
-                  ->GetModelFile("dummy_hand_detector.tflite")
+                  ->GetFile("dummy_hand_detector.tflite")
                   .status());
  MP_EXPECT_OK(hand_landmaker_model_bundle_resources
-                  ->GetModelFile("dummy_hand_landmarker.tflite")
+                  ->GetFile("dummy_hand_landmarker.tflite")
                   .status());
}
@@ -175,7 +171,7 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidTFLiteModelFile) {
      ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
                                        std::move(model_file)));
  auto status_or_model_bundle_file =
-     model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite");
+     model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite");
  MP_EXPECT_OK(status_or_model_bundle_file.status());

  // Verify tflite model works.
@@ -200,11 +196,11 @@ TEST(ModelAssetBundleResourcesTest, ExtractInvalidModelFile) {
      auto model_bundle_resources,
      ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
                                        std::move(model_file)));
- auto status = model_bundle_resources->GetModelFile("not_found.task").status();
+ auto status = model_bundle_resources->GetFile("not_found.task").status();
  EXPECT_EQ(status.code(), absl::StatusCode::kNotFound);
- EXPECT_THAT(status.message(),
-             testing::HasSubstr(
-                 "No model file with name: not_found.task. All model files in "
-                 "the model asset bundle are: "));
+ EXPECT_THAT(
+     status.message(),
+     testing::HasSubstr("No file with name: not_found.task. All files in "
+                        "the model asset bundle are: "));
  EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
              testing::Optional(absl::Cord(
@@ -219,7 +215,7 @@ TEST(ModelAssetBundleResourcesTest, ListModelFiles) {
      auto model_bundle_resources,
      ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
                                        std::move(model_file)));
- auto model_files = model_bundle_resources->ListModelFiles();
+ auto model_files = model_bundle_resources->ListFiles();
  std::vector<std::string> expected_model_files = {
      "dummy_gesture_recognizer.tflite", "dummy_hand_landmarker.task"};
  std::sort(model_files.begin(), model_files.end());

View File

@@ -64,7 +64,7 @@ class ModelMetadataPopulator {
  // Loads associated files into the TFLite FlatBuffer model. The input is a map
  // of {filename, file contents}.
  //
- // Warning: this method removes any previoulsy present associated files.
+ // Warning: this method removes any previously present associated files.
  // Calling this method multiple time removes any associated files from
  // previous calls, so this method should usually be called only once.
  void LoadAssociatedFiles(

View File

@@ -31,8 +31,8 @@ PYBIND11_MODULE(_pywrap_metadata_version, m) {
  // Using pybind11 type conversions to convert between Python and native
  // C++ types. There are other options to provide access to native Python types
- // in C++ and vice versa. See the pybind 11 instrcution [1] for more details.
- // Type converstions is recommended by pybind11, though the main downside
+ // in C++ and vice versa. See the pybind 11 instruction [1] for more details.
+ // Type conversions is recommended by pybind11, though the main downside
  // is that a copy of the data must be made on every Python to C++ transition:
  // this is needed since the C++ and Python versions of the same type generally
  // wont have the same memory layout.

View File

@@ -79,7 +79,7 @@ TEST(MetadataVersionTest,
  auto metadata = metadata_builder.Finish();
  FinishModelMetadataBuffer(builder, metadata);
- // Gets the mimimum metadata parser version.
+ // Gets the minimum metadata parser version.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),
@@ -100,7 +100,7 @@ TEST(MetadataVersionTest,
  auto metadata = metadata_builder.Finish();
  builder.Finish(metadata);
- // Gets the mimimum metadata parser version and triggers error.
+ // Gets the minimum metadata parser version and triggers error.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),
@@ -121,7 +121,7 @@ TEST(MetadataVersionTest,
  metadata_builder.add_associated_files(associated_files);
  FinishModelMetadataBuffer(builder, metadata_builder.Finish());
- // Gets the mimimum metadata parser version.
+ // Gets the minimum metadata parser version.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),
@@ -147,7 +147,7 @@ TEST(MetadataVersionTest,
  metadata_builder.add_subgraph_metadata(subgraphs);
  FinishModelMetadataBuffer(builder, metadata_builder.Finish());
- // Gets the mimimum metadata parser version.
+ // Gets the minimum metadata parser version.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),
@@ -172,7 +172,7 @@ TEST(MetadataVersionTest,
      std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
  CreateModelWithMetadata(tensors, builder);
- // Gets the mimimum metadata parser version.
+ // Gets the minimum metadata parser version.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),
@@ -203,7 +203,7 @@ TEST(MetadataVersionTest,
  metadata_builder.add_subgraph_metadata(subgraphs);
  FinishModelMetadataBuffer(builder, metadata_builder.Finish());
- // Gets the mimimum metadata parser version.
+ // Gets the minimum metadata parser version.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),
@@ -234,7 +234,7 @@ TEST(MetadataVersionTest,
  metadata_builder.add_subgraph_metadata(subgraphs);
  FinishModelMetadataBuffer(builder, metadata_builder.Finish());
- // Gets the mimimum metadata parser version.
+ // Gets the minimum metadata parser version.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),
@@ -294,7 +294,7 @@ TEST(MetadataVersionTest,
      std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
  CreateModelWithMetadata(tensors, builder);
- // Gets the mimimum metadata parser version.
+ // Gets the minimum metadata parser version.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),
@@ -323,7 +323,7 @@ TEST(MetadataVersionTest,
      std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
  CreateModelWithMetadata(tensors, builder);
- // Gets the mimimum metadata parser version.
+ // Gets the minimum metadata parser version.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),
@@ -348,7 +348,7 @@ TEST(MetadataVersionTest,
  metadata_builder.add_subgraph_metadata(subgraphs);
  FinishModelMetadataBuffer(builder, metadata_builder.Finish());
- // Gets the mimimum metadata parser version.
+ // Gets the minimum metadata parser version.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),
@@ -373,7 +373,7 @@ TEST(MetadataVersionTest,
  metadata_builder.add_subgraph_metadata(subgraphs);
  FinishModelMetadataBuffer(builder, metadata_builder.Finish());
- // Gets the mimimum metadata parser version.
+ // Gets the minimum metadata parser version.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),
@@ -404,7 +404,7 @@ TEST(MetadataVersionTest,
  metadata_builder.add_subgraph_metadata(subgraphs);
  FinishModelMetadataBuffer(builder, metadata_builder.Finish());
- // Gets the mimimum metadata parser version.
+ // Gets the minimum metadata parser version.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),
@@ -431,7 +431,7 @@ TEST(MetadataVersionTest,
      std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
  CreateModelWithMetadata(tensors, builder);
- // Gets the mimimum metadata parser version.
+ // Gets the minimum metadata parser version.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),
@@ -453,7 +453,7 @@ TEST(MetadataVersionTest,
  metadata_builder.add_associated_files(associated_files);
  FinishModelMetadataBuffer(builder, metadata_builder.Finish());
- // Gets the mimimum metadata parser version.
+ // Gets the minimum metadata parser version.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),
@@ -476,7 +476,7 @@ TEST(MetadataVersionTest,
  metadata_builder.add_associated_files(associated_files);
  FinishModelMetadataBuffer(builder, metadata_builder.Finish());
- // Gets the mimimum metadata parser version.
+ // Gets the minimum metadata parser version.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),
@@ -504,7 +504,7 @@ TEST(MetadataVersionTest, GetMinimumMetadataParserVersionForOptions) {
  metadata_builder.add_subgraph_metadata(subgraphs);
  FinishModelMetadataBuffer(builder, metadata_builder.Finish());
- // Gets the mimimum metadata parser version.
+ // Gets the minimum metadata parser version.
  std::string min_version;
  EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
                                            builder.GetSize(), &min_version),

View File

@@ -34,7 +34,7 @@ constexpr char kTestSPModelPath[] =
std::unique_ptr<SentencePieceTokenizer> CreateSentencePieceTokenizer(
    absl::string_view model_path) {
- // We are using `LoadBinaryContent()` instead of loading the model direclty
+ // We are using `LoadBinaryContent()` instead of loading the model directly
  // via `SentencePieceTokenizer` so that the file can be located on Windows
  std::string buffer = LoadBinaryContent(kTestSPModelPath);
  return absl::make_unique<SentencePieceTokenizer>(buffer.data(),

View File

@@ -60,6 +60,7 @@ cc_library(
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/port:statusor",
        "//mediapipe/tasks/cc:common",
        "//mediapipe/tasks/cc/core:external_file_handler",
        "//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
        "//mediapipe/tasks/cc/vision/face_geometry/libs:geometry_pipeline",
@@ -69,6 +70,7 @@ cc_library(
        "//mediapipe/tasks/cc/vision/face_geometry/proto:geometry_pipeline_metadata_cc_proto",
        "//mediapipe/util:resource_util",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings:str_format",
    ],
    alwayslink = 1,
)

View File

@ -18,12 +18,14 @@
#include <vector> #include <vector>
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/strings/str_format.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_macros.h" #include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/port/statusor.h" #include "mediapipe/framework/port/statusor.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/external_file_handler.h" #include "mediapipe/tasks/cc/core/external_file_handler.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.pb.h" #include "mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.pb.h"
@ -41,13 +43,50 @@ static constexpr char kEnvironmentTag[] = "ENVIRONMENT";
static constexpr char kImageSizeTag[] = "IMAGE_SIZE"; static constexpr char kImageSizeTag[] = "IMAGE_SIZE";
static constexpr char kMultiFaceGeometryTag[] = "MULTI_FACE_GEOMETRY"; static constexpr char kMultiFaceGeometryTag[] = "MULTI_FACE_GEOMETRY";
static constexpr char kMultiFaceLandmarksTag[] = "MULTI_FACE_LANDMARKS"; static constexpr char kMultiFaceLandmarksTag[] = "MULTI_FACE_LANDMARKS";
static constexpr char kFaceGeometryTag[] = "FACE_GEOMETRY";
static constexpr char kFaceLandmarksTag[] = "FACE_LANDMARKS";
using ::mediapipe::tasks::vision::face_geometry::proto::Environment; using ::mediapipe::tasks::vision::face_geometry::proto::Environment;
using ::mediapipe::tasks::vision::face_geometry::proto::FaceGeometry; using ::mediapipe::tasks::vision::face_geometry::proto::FaceGeometry;
using ::mediapipe::tasks::vision::face_geometry::proto:: using ::mediapipe::tasks::vision::face_geometry::proto::
GeometryPipelineMetadata; GeometryPipelineMetadata;
// A calculator that renders a visual effect for multiple faces. absl::Status SanityCheck(CalculatorContract* cc) {
if (!(cc->Inputs().HasTag(kFaceLandmarksTag) ^
cc->Inputs().HasTag(kMultiFaceLandmarksTag))) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrFormat("Only one of %s and %s can be set at a time.",
kFaceLandmarksTag, kMultiFaceLandmarksTag));
}
if (!(cc->Outputs().HasTag(kFaceGeometryTag) ^
cc->Outputs().HasTag(kMultiFaceGeometryTag))) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrFormat("Only one of %s and %s can be set at a time.",
kFaceGeometryTag, kMultiFaceGeometryTag));
}
if (cc->Inputs().HasTag(kFaceLandmarksTag) !=
cc->Outputs().HasTag(kFaceGeometryTag)) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrFormat(
"%s and %s must both be set or neither be set and a time.",
kFaceLandmarksTag, kFaceGeometryTag));
}
if (cc->Inputs().HasTag(kMultiFaceLandmarksTag) !=
cc->Outputs().HasTag(kMultiFaceGeometryTag)) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrFormat(
"%s and %s must both be set or neither be set and a time.",
kMultiFaceLandmarksTag, kMultiFaceGeometryTag));
}
return absl::OkStatus();
}
// A calculator that renders a visual effect for multiple faces. Support single
// face landmarks or multiple face landmarks.
// //
// Inputs: // Inputs:
// IMAGE_SIZE (`std::pair<int, int>`, required): // IMAGE_SIZE (`std::pair<int, int>`, required):
@ -58,8 +97,12 @@ using ::mediapipe::tasks::vision::face_geometry::proto::
// ratio. If used as-is, the resulting face geometry visualization should be // ratio. If used as-is, the resulting face geometry visualization should be
// happening on a frame with the same ratio as well. // happening on a frame with the same ratio as well.
// //
// MULTI_FACE_LANDMARKS (`std::vector<NormalizedLandmarkList>`, required): // MULTI_FACE_LANDMARKS (`std::vector<NormalizedLandmarkList>`, optional):
// A vector of face landmark lists. // A vector of face landmark lists. If connected, the output stream
// MULTI_FACE_GEOMETRY must be connected.
// FACE_LANDMARKS (NormalizedLandmarkList, optional):
// A NormalizedLandmarkList of single face landmark lists. If connected, the
// output stream FACE_GEOMETRY must be connected.
// //
// Input side packets: // Input side packets:
// ENVIRONMENT (`proto::Environment`, required) // ENVIRONMENT (`proto::Environment`, required)
@ -67,8 +110,10 @@ using ::mediapipe::tasks::vision::face_geometry::proto::
// as well as virtual camera parameters. // as well as virtual camera parameters.
// //
// Output: // Output:
// MULTI_FACE_GEOMETRY (`std::vector<FaceGeometry>`, required): // MULTI_FACE_GEOMETRY (`std::vector<FaceGeometry>`, optional):
// A vector of face geometry data. // A vector of face geometry data if MULTI_FACE_LANDMARKS is connected .
// FACE_GEOMETRY (FaceGeometry, optional):
// A FaceGeometry of the face landmarks if FACE_LANDMARKS is connected.
// //
// Options: // Options:
// metadata_file (`ExternalFile`, optional): // metadata_file (`ExternalFile`, optional):
@ -81,13 +126,21 @@ class GeometryPipelineCalculator : public CalculatorBase {
public: public:
static absl::Status GetContract(CalculatorContract* cc) { static absl::Status GetContract(CalculatorContract* cc) {
cc->InputSidePackets().Tag(kEnvironmentTag).Set<Environment>(); cc->InputSidePackets().Tag(kEnvironmentTag).Set<Environment>();
MP_RETURN_IF_ERROR(SanityCheck(cc));
cc->Inputs().Tag(kImageSizeTag).Set<std::pair<int, int>>(); cc->Inputs().Tag(kImageSizeTag).Set<std::pair<int, int>>();
if (cc->Inputs().HasTag(kMultiFaceLandmarksTag)) {
cc->Inputs() cc->Inputs()
.Tag(kMultiFaceLandmarksTag) .Tag(kMultiFaceLandmarksTag)
.Set<std::vector<mediapipe::NormalizedLandmarkList>>(); .Set<std::vector<mediapipe::NormalizedLandmarkList>>();
cc->Outputs().Tag(kMultiFaceGeometryTag).Set<std::vector<FaceGeometry>>(); cc->Outputs().Tag(kMultiFaceGeometryTag).Set<std::vector<FaceGeometry>>();
return absl::OkStatus(); return absl::OkStatus();
} else {
cc->Inputs()
.Tag(kFaceLandmarksTag)
.Set<mediapipe::NormalizedLandmarkList>();
cc->Outputs().Tag(kFaceGeometryTag).Set<FaceGeometry>();
return absl::OkStatus();
}
} }
absl::Status Open(CalculatorContext* cc) override { absl::Status Open(CalculatorContext* cc) override {
@ -112,7 +165,6 @@ class GeometryPipelineCalculator : public CalculatorBase {
ASSIGN_OR_RETURN(geometry_pipeline_, ASSIGN_OR_RETURN(geometry_pipeline_,
CreateGeometryPipeline(environment, metadata), CreateGeometryPipeline(environment, metadata),
_ << "Failed to create a geometry pipeline!"); _ << "Failed to create a geometry pipeline!");
return absl::OkStatus(); return absl::OkStatus();
} }
@ -121,12 +173,15 @@ class GeometryPipelineCalculator : public CalculatorBase {
// to have a non-empty packet. In case this requirement is not met, there's // to have a non-empty packet. In case this requirement is not met, there's
// nothing to be processed at the current timestamp. // nothing to be processed at the current timestamp.
if (cc->Inputs().Tag(kImageSizeTag).IsEmpty() || if (cc->Inputs().Tag(kImageSizeTag).IsEmpty() ||
cc->Inputs().Tag(kMultiFaceLandmarksTag).IsEmpty()) { (cc->Inputs().Tag(kMultiFaceLandmarksTag).IsEmpty() &&
cc->Inputs().Tag(kFaceLandmarksTag).IsEmpty())) {
return absl::OkStatus(); return absl::OkStatus();
} }
const auto& image_size = const auto& image_size =
cc->Inputs().Tag(kImageSizeTag).Get<std::pair<int, int>>(); cc->Inputs().Tag(kImageSizeTag).Get<std::pair<int, int>>();
if (cc->Inputs().HasTag(kMultiFaceLandmarksTag)) {
const auto& multi_face_landmarks = const auto& multi_face_landmarks =
cc->Inputs() cc->Inputs()
.Tag(kMultiFaceLandmarksTag) .Tag(kMultiFaceLandmarksTag)
@ -147,6 +202,25 @@ class GeometryPipelineCalculator : public CalculatorBase {
.AddPacket(mediapipe::Adopt<std::vector<FaceGeometry>>( .AddPacket(mediapipe::Adopt<std::vector<FaceGeometry>>(
multi_face_geometry.release()) multi_face_geometry.release())
.At(cc->InputTimestamp())); .At(cc->InputTimestamp()));
} else {
const auto& face_landmarks =
cc->Inputs()
.Tag(kFaceLandmarksTag)
.Get<mediapipe::NormalizedLandmarkList>();
ASSIGN_OR_RETURN(
std::vector<FaceGeometry> multi_face_geometry,
geometry_pipeline_->EstimateFaceGeometry(
{face_landmarks}, //
/*frame_width*/ image_size.first,
/*frame_height*/ image_size.second),
_ << "Failed to estimate face geometry for multiple faces!");
cc->Outputs()
.Tag(kFaceGeometryTag)
.AddPacket(mediapipe::MakePacket<FaceGeometry>(multi_face_geometry[0])
.At(cc->InputTimestamp()));
}
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
load("//mediapipe/framework:mediapipe_register_type.bzl", "mediapipe_register_type")
licenses(["notice"]) licenses(["notice"])
@ -23,6 +24,16 @@ mediapipe_proto_library(
srcs = ["environment.proto"], srcs = ["environment.proto"],
) )
mediapipe_register_type(
base_name = "face_geometry",
include_headers = ["mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h"],
types = [
"::mediapipe::tasks::vision::face_geometry::proto::FaceGeometry",
"::std::vector<::mediapipe::tasks::vision::face_geometry::proto::FaceGeometry>",
],
deps = [":face_geometry_cc_proto"],
)
mediapipe_proto_library( mediapipe_proto_library(
name = "face_geometry_proto", name = "face_geometry_proto",
srcs = ["face_geometry.proto"], srcs = ["face_geometry.proto"],

View File

@ -116,7 +116,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
options->mutable_face_detector_graph_options(); options->mutable_face_detector_graph_options();
if (!face_detector_graph_options->base_options().has_model_asset()) { if (!face_detector_graph_options->base_options().has_model_asset()) {
ASSIGN_OR_RETURN(const auto face_detector_file, ASSIGN_OR_RETURN(const auto face_detector_file,
resources.GetModelFile(kFaceDetectorTFLiteName)); resources.GetFile(kFaceDetectorTFLiteName));
SetExternalFile(face_detector_file, SetExternalFile(face_detector_file,
face_detector_graph_options->mutable_base_options() face_detector_graph_options->mutable_base_options()
->mutable_model_asset(), ->mutable_model_asset(),
@ -132,7 +132,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
if (!face_landmarks_detector_graph_options->base_options() if (!face_landmarks_detector_graph_options->base_options()
.has_model_asset()) { .has_model_asset()) {
ASSIGN_OR_RETURN(const auto face_landmarks_detector_file, ASSIGN_OR_RETURN(const auto face_landmarks_detector_file,
resources.GetModelFile(kFaceLandmarksDetectorTFLiteName)); resources.GetFile(kFaceLandmarksDetectorTFLiteName));
SetExternalFile( SetExternalFile(
face_landmarks_detector_file, face_landmarks_detector_file,
face_landmarks_detector_graph_options->mutable_base_options() face_landmarks_detector_graph_options->mutable_base_options()
@ -146,7 +146,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
->set_use_stream_mode(options->base_options().use_stream_mode()); ->set_use_stream_mode(options->base_options().use_stream_mode());
absl::StatusOr<absl::string_view> face_blendshape_model = absl::StatusOr<absl::string_view> face_blendshape_model =
resources.GetModelFile(kFaceBlendshapeTFLiteName); resources.GetFile(kFaceBlendshapeTFLiteName);
if (face_blendshape_model.ok()) { if (face_blendshape_model.ok()) {
SetExternalFile(*face_blendshape_model, SetExternalFile(*face_blendshape_model,
face_landmarks_detector_graph_options face_landmarks_detector_graph_options
@ -327,7 +327,7 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph {
// Set the face geometry metadata file for // Set the face geometry metadata file for
// FaceGeometryFromLandmarksGraph. // FaceGeometryFromLandmarksGraph.
ASSIGN_OR_RETURN(auto face_geometry_pipeline_metadata_file, ASSIGN_OR_RETURN(auto face_geometry_pipeline_metadata_file,
model_asset_bundle_resources->GetModelFile( model_asset_bundle_resources->GetFile(
kFaceGeometryPipelineMetadataName)); kFaceGeometryPipelineMetadataName));
SetExternalFile(face_geometry_pipeline_metadata_file, SetExternalFile(face_geometry_pipeline_metadata_file,
sc->MutableOptions<FaceLandmarkerGraphOptions>() sc->MutableOptions<FaceLandmarkerGraphOptions>()

View File

@ -92,7 +92,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
GestureRecognizerGraphOptions* options, GestureRecognizerGraphOptions* options,
bool is_copy) { bool is_copy) {
ASSIGN_OR_RETURN(const auto hand_landmarker_file, ASSIGN_OR_RETURN(const auto hand_landmarker_file,
resources.GetModelFile(kHandLandmarkerBundleAssetName)); resources.GetFile(kHandLandmarkerBundleAssetName));
auto* hand_landmarker_graph_options = auto* hand_landmarker_graph_options =
options->mutable_hand_landmarker_graph_options(); options->mutable_hand_landmarker_graph_options();
SetExternalFile(hand_landmarker_file, SetExternalFile(hand_landmarker_file,
@ -105,9 +105,8 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
hand_landmarker_graph_options->mutable_base_options()->set_use_stream_mode( hand_landmarker_graph_options->mutable_base_options()->set_use_stream_mode(
options->base_options().use_stream_mode()); options->base_options().use_stream_mode());
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(const auto hand_gesture_recognizer_file,
const auto hand_gesture_recognizer_file, resources.GetFile(kHandGestureRecognizerBundleAssetName));
resources.GetModelFile(kHandGestureRecognizerBundleAssetName));
auto* hand_gesture_recognizer_graph_options = auto* hand_gesture_recognizer_graph_options =
options->mutable_hand_gesture_recognizer_graph_options(); options->mutable_hand_gesture_recognizer_graph_options();
SetExternalFile(hand_gesture_recognizer_file, SetExternalFile(hand_gesture_recognizer_file,
@ -127,7 +126,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
->mutable_acceleration() ->mutable_acceleration()
->mutable_xnnpack(); ->mutable_xnnpack();
LOG(WARNING) << "Hand Gesture Recognizer contains CPU only ops. Sets " LOG(WARNING) << "Hand Gesture Recognizer contains CPU only ops. Sets "
<< "HandGestureRecognizerGraph acceleartion to Xnnpack."; << "HandGestureRecognizerGraph acceleration to Xnnpack.";
} }
hand_gesture_recognizer_graph_options->mutable_base_options() hand_gesture_recognizer_graph_options->mutable_base_options()
->set_use_stream_mode(options->base_options().use_stream_mode()); ->set_use_stream_mode(options->base_options().use_stream_mode());

View File

@ -207,7 +207,7 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
HandGestureRecognizerGraphOptions* options, HandGestureRecognizerGraphOptions* options,
bool is_copy) { bool is_copy) {
ASSIGN_OR_RETURN(const auto gesture_embedder_file, ASSIGN_OR_RETURN(const auto gesture_embedder_file,
resources.GetModelFile(kGestureEmbedderTFLiteName)); resources.GetFile(kGestureEmbedderTFLiteName));
auto* gesture_embedder_graph_options = auto* gesture_embedder_graph_options =
options->mutable_gesture_embedder_graph_options(); options->mutable_gesture_embedder_graph_options();
SetExternalFile(gesture_embedder_file, SetExternalFile(gesture_embedder_file,
@ -218,9 +218,8 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
options->base_options(), options->base_options(),
gesture_embedder_graph_options->mutable_base_options()); gesture_embedder_graph_options->mutable_base_options());
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(const auto canned_gesture_classifier_file,
const auto canned_gesture_classifier_file, resources.GetFile(kCannedGestureClassifierTFLiteName));
resources.GetModelFile(kCannedGestureClassifierTFLiteName));
auto* canned_gesture_classifier_graph_options = auto* canned_gesture_classifier_graph_options =
options->mutable_canned_gesture_classifier_graph_options(); options->mutable_canned_gesture_classifier_graph_options();
SetExternalFile( SetExternalFile(
@ -233,7 +232,7 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
canned_gesture_classifier_graph_options->mutable_base_options()); canned_gesture_classifier_graph_options->mutable_base_options());
const auto custom_gesture_classifier_file = const auto custom_gesture_classifier_file =
resources.GetModelFile(kCustomGestureClassifierTFLiteName); resources.GetFile(kCustomGestureClassifierTFLiteName);
if (custom_gesture_classifier_file.ok()) { if (custom_gesture_classifier_file.ok()) {
has_custom_gesture_classifier = true; has_custom_gesture_classifier = true;
auto* custom_gesture_classifier_graph_options = auto* custom_gesture_classifier_graph_options =

View File

@ -101,7 +101,7 @@ class HandLandmarker : tasks::vision::core::BaseVisionTaskApi {
// three running modes: // three running modes:
// 1) Image mode for detecting hand landmarks on single image inputs. Users // 1) Image mode for detecting hand landmarks on single image inputs. Users
// provide mediapipe::Image to the `Detect` method, and will receive the // provide mediapipe::Image to the `Detect` method, and will receive the
// deteced hand landmarks results as the return value. // detected hand landmarks results as the return value.
// 2) Video mode for detecting hand landmarks on the decoded frames of a // 2) Video mode for detecting hand landmarks on the decoded frames of a
// video. Users call `DetectForVideo` method, and will receive the detected // video. Users call `DetectForVideo` method, and will receive the detected
// hand landmarks results as the return value. // hand landmarks results as the return value.

View File

@ -97,7 +97,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
options->mutable_hand_detector_graph_options(); options->mutable_hand_detector_graph_options();
if (!hand_detector_graph_options->base_options().has_model_asset()) { if (!hand_detector_graph_options->base_options().has_model_asset()) {
ASSIGN_OR_RETURN(const auto hand_detector_file, ASSIGN_OR_RETURN(const auto hand_detector_file,
resources.GetModelFile(kHandDetectorTFLiteName)); resources.GetFile(kHandDetectorTFLiteName));
SetExternalFile(hand_detector_file, SetExternalFile(hand_detector_file,
hand_detector_graph_options->mutable_base_options() hand_detector_graph_options->mutable_base_options()
->mutable_model_asset(), ->mutable_model_asset(),
@ -113,7 +113,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
if (!hand_landmarks_detector_graph_options->base_options() if (!hand_landmarks_detector_graph_options->base_options()
.has_model_asset()) { .has_model_asset()) {
ASSIGN_OR_RETURN(const auto hand_landmarks_detector_file, ASSIGN_OR_RETURN(const auto hand_landmarks_detector_file,
resources.GetModelFile(kHandLandmarksDetectorTFLiteName)); resources.GetFile(kHandLandmarksDetectorTFLiteName));
SetExternalFile( SetExternalFile(
hand_landmarks_detector_file, hand_landmarks_detector_file,
hand_landmarks_detector_graph_options->mutable_base_options() hand_landmarks_detector_graph_options->mutable_base_options()

View File

@ -409,7 +409,7 @@ REGISTER_MEDIAPIPE_GRAPH(
// - Accepts CPU input image and a vector of hand rect RoIs to detect the // - Accepts CPU input image and a vector of hand rect RoIs to detect the
// multiple hands landmarks enclosed by the RoIs. Output vectors of // multiple hands landmarks enclosed by the RoIs. Output vectors of
// hand landmarks related results, where each element in the vectors // hand landmarks related results, where each element in the vectors
// corrresponds to the result of the same hand. // corresponds to the result of the same hand.
// //
// Inputs: // Inputs:
// IMAGE - Image // IMAGE - Image

View File

@ -52,7 +52,7 @@ constexpr char kMobileNetV3Embedder[] =
constexpr double kSimilarityTolerancy = 1e-6; constexpr double kSimilarityTolerancy = 1e-6;
// Utility function to check the sizes, head_index and head_names of a result // Utility function to check the sizes, head_index and head_names of a result
// procuded by kMobileNetV3Embedder. // produced by kMobileNetV3Embedder.
void CheckMobileNetV3Result(const ImageEmbedderResult& result, bool quantized) { void CheckMobileNetV3Result(const ImageEmbedderResult& result, bool quantized) {
EXPECT_EQ(result.embeddings.size(), 1); EXPECT_EQ(result.embeddings.size(), 1);
EXPECT_EQ(result.embeddings[0].head_index, 0); EXPECT_EQ(result.embeddings[0].head_index, 0);

View File

@ -25,6 +25,7 @@ cc_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":image_segmenter_graph", ":image_segmenter_graph",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
@ -34,10 +35,13 @@ cc_library(
"//mediapipe/tasks/cc/vision/core:image_processing_options", "//mediapipe/tasks/cc/vision/core:image_processing_options",
"//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:running_mode",
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_cc_proto",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto", "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto",
"//mediapipe/util:label_map_cc_proto",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
], ],
) )

View File

@ -22,6 +22,9 @@ import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto"; import "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto";
import "mediapipe/util/label_map.proto"; import "mediapipe/util/label_map.proto";
option java_package = "com.google.mediapipe.tasks";
option java_outer_classname = "TensorsToSegmentationCalculatorOptionsProto";
message TensorsToSegmentationCalculatorOptions { message TensorsToSegmentationCalculatorOptions {
extend mediapipe.CalculatorOptions { extend mediapipe.CalculatorOptions {
optional TensorsToSegmentationCalculatorOptions ext = 458105876; optional TensorsToSegmentationCalculatorOptions ext = 458105876;

View File

@ -15,15 +15,21 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h" #include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h"
#include <optional>
#include "absl/strings/str_format.h"
#include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h"
#include "mediapipe/util/label_map.pb.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
@ -112,6 +118,39 @@ ConvertImageSegmenterOptionsToProto(ImageSegmenterOptions* options) {
return options_proto; return options_proto;
} }
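// Scans the graph config for the mediapipe.tasks.TensorsToSegmentationCalculator
// node and returns the category labels stored in its label_items map, in index
// order. Fails if more than one such node is present or if an expected index is
// missing from the labelmap.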
absl::StatusOr<std::vector<std::string>> GetLabelsFromGraphConfig(
const CalculatorGraphConfig& graph_config) {
bool found_tensor_to_segmentation_calculator = false;
std::vector<std::string> labels;
for (const auto& node : graph_config.node()) {
if (node.calculator() ==
"mediapipe.tasks.TensorsToSegmentationCalculator") {
if (!found_tensor_to_segmentation_calculator) {
found_tensor_to_segmentation_calculator = true;
} else {
return absl::Status(CreateStatusWithPayload(
absl::StatusCode::kFailedPrecondition,
"The graph has more than one "
"mediapipe.tasks.TensorsToSegmentationCalculator."));
}
TensorsToSegmentationCalculatorOptions options =
node.options().GetExtension(
TensorsToSegmentationCalculatorOptions::ext);
if (!options.label_items().empty()) {
for (int i = 0; i < options.label_items_size(); ++i) {
if (!options.label_items().contains(i)) {
return absl::Status(CreateStatusWithPayload(
absl::StatusCode::kFailedPrecondition,
absl::StrFormat("The lablemap have no expected key: %d.", i)));
}
labels.push_back(options.label_items().at(i).name());
}
}
}
}
return labels;
}
} // namespace } // namespace
absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create( absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
@ -140,13 +179,22 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
kMicroSecondsPerMilliSecond); kMicroSecondsPerMilliSecond);
}; };
} }
return core::VisionTaskApiFactory::Create<ImageSegmenter,
auto image_segmenter =
core::VisionTaskApiFactory::Create<ImageSegmenter,
ImageSegmenterGraphOptionsProto>( ImageSegmenterGraphOptionsProto>(
CreateGraphConfig( CreateGraphConfig(
std::move(options_proto), std::move(options_proto),
options->running_mode == core::RunningMode::LIVE_STREAM), options->running_mode == core::RunningMode::LIVE_STREAM),
std::move(options->base_options.op_resolver), options->running_mode, std::move(options->base_options.op_resolver), options->running_mode,
std::move(packets_callback)); std::move(packets_callback));
if (!image_segmenter.ok()) {
return image_segmenter.status();
}
ASSIGN_OR_RETURN(
(*image_segmenter)->labels_,
GetLabelsFromGraphConfig((*image_segmenter)->runner_->GetGraphConfig()));
return image_segmenter;
} }
absl::StatusOr<std::vector<Image>> ImageSegmenter::Segment( absl::StatusOr<std::vector<Image>> ImageSegmenter::Segment(

View File

@ -189,6 +189,18 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
// Shuts down the ImageSegmenter when all works are done. // Shuts down the ImageSegmenter when all works are done.
absl::Status Close() { return runner_->Close(); } absl::Status Close() { return runner_->Close(); }
// Gets the list of category labels that the ImageSegmenter can recognize. For
// CATEGORY_MASK type, the index in the category mask corresponds to the
// category in the label list. For CONFIDENCE_MASK type, the output mask at a
// given index corresponds to the category at the same index in the label
// list.
//
// If the model file provides no labelmap, an empty label list is returned.
std::vector<std::string> GetLabels() { return labels_; }
private:
std::vector<std::string> labels_;
}; };
} // namespace image_segmenter } // namespace image_segmenter
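A short usage sketch for the new accessor, mapping a category index from a CATEGORY_MASK output back to its label. The mask representation below (one class index per pixel in a flat vector) and the helper name are illustrative assumptions, not part of this change:

#include <cstdint>
#include <string>
#include <vector>

#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h"

// Looks up the label for one pixel's category index. Assumes `segmenter` was
// created with OutputType::CATEGORY_MASK and `category_mask` holds one class
// index per pixel, extracted from the Image returned by Segment().
std::string LabelForPixel(
    mediapipe::tasks::vision::image_segmenter::ImageSegmenter& segmenter,
    const std::vector<uint8_t>& category_mask, int pixel_index) {
  const std::vector<std::string> labels = segmenter.GetLabels();
  const int category = category_mask[pixel_index];
  // Models shipped without a labelmap yield an empty label list.
  if (category < 0 || category >= static_cast<int>(labels.size())) {
    return "unknown";
  }
  return labels[category];
}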

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h" #include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h"
#include <array>
#include <cstdint> #include <cstdint>
#include <memory> #include <memory>
@ -71,6 +72,13 @@ constexpr float kGoldenMaskSimilarity = 0.98;
// 20 means class index 2, etc. // 20 means class index 2, etc.
constexpr int kGoldenMaskMagnificationFactor = 10; constexpr int kGoldenMaskMagnificationFactor = 10;
constexpr std::array<absl::string_view, 21> kDeeplabLabelNames = {
"background", "aeroplane", "bicycle", "bird", "boat",
"bottle", "bus", "car", "cat", "chair",
"cow", "dining table", "dog", "horse", "motorbike",
"person", "potted plant", "sheep", "sofa", "train",
"tv"};
// Intentionally converting output into CV_8UC1 and then again into CV_32FC1 // Intentionally converting output into CV_8UC1 and then again into CV_32FC1
// as expected outputs are stored in CV_8UC1, so this conversion allows to do // as expected outputs are stored in CV_8UC1, so this conversion allows to do
// fair comparison. // fair comparison.
@ -244,6 +252,22 @@ TEST_F(CreateFromOptionsTest, FailsWithInputChannelOneModel) {
"channels = 3 or 4.")); "channels = 3 or 4."));
} }
TEST(GetLabelsTest, SucceedsWithLabelsInModel) {
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
const auto& labels = segmenter->GetLabels();
ASSERT_FALSE(labels.empty());
ASSERT_EQ(labels.size(), kDeeplabLabelNames.size());
for (int i = 0; i < labels.size(); ++i) {
EXPECT_EQ(labels[i], kDeeplabLabelNames[i]);
}
}
class ImageModeTest : public tflite_shims::testing::Test {}; class ImageModeTest : public tflite_shims::testing::Test {};
TEST_F(ImageModeTest, SucceedsWithCategoryMask) { TEST_F(ImageModeTest, SucceedsWithCategoryMask) {

View File

@ -108,31 +108,31 @@ std::vector<DetectionProto>
GenerateMobileSsdNoImageResizingFullExpectedResults() { GenerateMobileSsdNoImageResizingFullExpectedResults() {
return {ParseTextProtoOrDie<DetectionProto>(R"pb( return {ParseTextProtoOrDie<DetectionProto>(R"pb(
label: "cat" label: "cat"
score: 0.6328125 score: 0.6210937
location_data { location_data {
format: BOUNDING_BOX format: BOUNDING_BOX
bounding_box { xmin: 14 ymin: 197 width: 98 height: 99 } bounding_box { xmin: 15 ymin: 197 width: 98 height: 99 }
})pb"), })pb"),
ParseTextProtoOrDie<DetectionProto>(R"pb( ParseTextProtoOrDie<DetectionProto>(R"pb(
label: "cat" label: "cat"
score: 0.59765625 score: 0.609375
location_data { location_data {
format: BOUNDING_BOX format: BOUNDING_BOX
bounding_box { xmin: 151 ymin: 78 width: 104 height: 223 } bounding_box { xmin: 150 ymin: 78 width: 104 height: 223 }
})pb"), })pb"),
ParseTextProtoOrDie<DetectionProto>(R"pb( ParseTextProtoOrDie<DetectionProto>(R"pb(
label: "cat" label: "cat"
score: 0.5 score: 0.5
location_data { location_data {
format: BOUNDING_BOX format: BOUNDING_BOX
bounding_box { xmin: 65 ymin: 199 width: 41 height: 101 } bounding_box { xmin: 64 ymin: 199 width: 42 height: 101 }
})pb"), })pb"),
ParseTextProtoOrDie<DetectionProto>(R"pb( ParseTextProtoOrDie<DetectionProto>(R"pb(
label: "dog" label: "dog"
score: 0.48828125 score: 0.5
location_data { location_data {
format: BOUNDING_BOX format: BOUNDING_BOX
bounding_box { xmin: 12 ymin: 110 width: 153 height: 193 } bounding_box { xmin: 14 ymin: 110 width: 153 height: 193 }
})pb")}; })pb")};
} }
@ -268,7 +268,7 @@ TEST_F(CreateFromOptionsTest, FailsWithIllegalCallbackInImageOrVideoMode) {
options->running_mode = running_mode; options->running_mode = running_mode;
options->result_callback = options->result_callback =
[](absl::StatusOr<ObjectDetectorResult> detections, const Image& image, [](absl::StatusOr<ObjectDetectorResult> detections, const Image& image,
int64 timestamp_ms) {}; int64_t timestamp_ms) {};
absl::StatusOr<std::unique_ptr<ObjectDetector>> object_detector = absl::StatusOr<std::unique_ptr<ObjectDetector>> object_detector =
ObjectDetector::Create(std::move(options)); ObjectDetector::Create(std::move(options));
EXPECT_EQ(object_detector.status().code(), EXPECT_EQ(object_detector.status().code(),
@ -381,28 +381,28 @@ TEST_F(ImageModeTest, Succeeds) {
score: 0.69921875 score: 0.69921875
location_data { location_data {
format: BOUNDING_BOX format: BOUNDING_BOX
bounding_box { xmin: 608 ymin: 161 width: 381 height: 439 } bounding_box { xmin: 608 ymin: 164 width: 381 height: 432 }
})pb"), })pb"),
ParseTextProtoOrDie<DetectionProto>(R"pb( ParseTextProtoOrDie<DetectionProto>(R"pb(
label: "cat" label: "cat"
score: 0.64453125 score: 0.65625
location_data { location_data {
format: BOUNDING_BOX format: BOUNDING_BOX
bounding_box { xmin: 60 ymin: 398 width: 386 height: 196 } bounding_box { xmin: 57 ymin: 398 width: 386 height: 196 }
})pb"), })pb"),
ParseTextProtoOrDie<DetectionProto>(R"pb( ParseTextProtoOrDie<DetectionProto>(R"pb(
label: "cat" label: "cat"
score: 0.51171875 score: 0.51171875
location_data { location_data {
format: BOUNDING_BOX format: BOUNDING_BOX
bounding_box { xmin: 256 ymin: 395 width: 173 height: 202 } bounding_box { xmin: 256 ymin: 394 width: 173 height: 202 }
})pb"), })pb"),
ParseTextProtoOrDie<DetectionProto>(R"pb( ParseTextProtoOrDie<DetectionProto>(R"pb(
label: "cat" label: "cat"
score: 0.48828125 score: 0.48828125
location_data { location_data {
format: BOUNDING_BOX format: BOUNDING_BOX
bounding_box { xmin: 362 ymin: 191 width: 325 height: 419 } bounding_box { xmin: 360 ymin: 195 width: 330 height: 412 }
})pb")})); })pb")}));
} }
@ -484,10 +484,10 @@ TEST_F(ImageModeTest, SucceedsWithScoreCalibration) {
results, results,
ConvertToDetectionResult({ParseTextProtoOrDie<DetectionProto>(R"pb( ConvertToDetectionResult({ParseTextProtoOrDie<DetectionProto>(R"pb(
label: "cat" label: "cat"
score: 0.6531269142 score: 0.650467276
location_data { location_data {
format: BOUNDING_BOX format: BOUNDING_BOX
bounding_box { xmin: 14 ymin: 197 width: 98 height: 99 } bounding_box { xmin: 15 ymin: 197 width: 98 height: 99 }
})pb")})); })pb")}));
} }
@ -507,9 +507,9 @@ TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) {
GenerateMobileSsdNoImageResizingFullExpectedResults(); GenerateMobileSsdNoImageResizingFullExpectedResults();
ExpectApproximatelyEqual( ExpectApproximatelyEqual(
results, ConvertToDetectionResult({full_expected_results[0], results, ConvertToDetectionResult(
full_expected_results[1], {full_expected_results[0], full_expected_results[1],
full_expected_results[2]})); full_expected_results[2], full_expected_results[3]}));
} }
TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) { TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) {
@ -685,7 +685,7 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
options->running_mode = core::RunningMode::LIVE_STREAM; options->running_mode = core::RunningMode::LIVE_STREAM;
options->result_callback = [](absl::StatusOr<ObjectDetectorResult> detections, options->result_callback = [](absl::StatusOr<ObjectDetectorResult> detections,
const Image& image, int64 timestamp_ms) {}; const Image& image, int64_t timestamp_ms) {};
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
ObjectDetector::Create(std::move(options))); ObjectDetector::Create(std::move(options)));
@ -716,7 +716,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
options->result_callback = [](absl::StatusOr<ObjectDetectorResult> detections, options->result_callback = [](absl::StatusOr<ObjectDetectorResult> detections,
const Image& image, int64 timestamp_ms) {}; const Image& image, int64_t timestamp_ms) {};
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
ObjectDetector::Create(std::move(options))); ObjectDetector::Create(std::move(options)));
MP_ASSERT_OK(object_detector->DetectAsync(image, 1)); MP_ASSERT_OK(object_detector->DetectAsync(image, 1));
@ -742,13 +742,13 @@ TEST_F(LiveStreamModeTest, Succeeds) {
options->running_mode = core::RunningMode::LIVE_STREAM; options->running_mode = core::RunningMode::LIVE_STREAM;
std::vector<ObjectDetectorResult> detection_results; std::vector<ObjectDetectorResult> detection_results;
std::vector<std::pair<int, int>> image_sizes; std::vector<std::pair<int, int>> image_sizes;
std::vector<int64> timestamps; std::vector<int64_t> timestamps;
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
options->result_callback = options->result_callback =
[&detection_results, &image_sizes, &timestamps]( [&detection_results, &image_sizes, &timestamps](
absl::StatusOr<ObjectDetectorResult> detections, const Image& image, absl::StatusOr<ObjectDetectorResult> detections, const Image& image,
int64 timestamp_ms) { int64_t timestamp_ms) {
MP_ASSERT_OK(detections.status()); MP_ASSERT_OK(detections.status());
detection_results.push_back(std::move(detections).value()); detection_results.push_back(std::move(detections).value());
image_sizes.push_back({image.width(), image.height()}); image_sizes.push_back({image.width(), image.height()});
@ -775,7 +775,7 @@ TEST_F(LiveStreamModeTest, Succeeds) {
EXPECT_EQ(image_size.first, image.width()); EXPECT_EQ(image_size.first, image.width());
EXPECT_EQ(image_size.second, image.height()); EXPECT_EQ(image_size.second, image.height());
} }
int64 timestamp_ms = -1; int64_t timestamp_ms = -1;
for (const auto& timestamp : timestamps) { for (const auto& timestamp : timestamps) {
EXPECT_GT(timestamp, timestamp_ms); EXPECT_GT(timestamp, timestamp_ms);
timestamp_ms = timestamp; timestamp_ms = timestamp;

View File

@ -0,0 +1,53 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = [
"//mediapipe/tasks:internal",
])
licenses(["notice"])
cc_library(
name = "pose_detector_graph",
srcs = ["pose_detector_graph.cc"],
deps = [
"//mediapipe/calculators/core:clip_vector_size_calculator_cc_proto",
"//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
"//mediapipe/calculators/tensor:inference_calculator",
"//mediapipe/calculators/tensor:tensors_to_detections_calculator",
"//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto",
"//mediapipe/calculators/tflite:ssd_anchors_calculator",
"//mediapipe/calculators/tflite:ssd_anchors_calculator_cc_proto",
"//mediapipe/calculators/util:detection_projection_calculator",
"//mediapipe/calculators/util:detection_transformation_calculator",
"//mediapipe/calculators/util:detections_to_rects_calculator",
"//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto",
"//mediapipe/calculators/util:non_max_suppression_calculator",
"//mediapipe/calculators/util:non_max_suppression_calculator_cc_proto",
"//mediapipe/calculators/util:rect_transformation_calculator",
"//mediapipe/calculators/util:rect_transformation_calculator_cc_proto",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:subgraph",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_cc_proto",
"@com_google_absl//absl/status:statusor",
],
alwayslink = 1,
)

View File

@ -0,0 +1,354 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <vector>
#include "absl/status/statusor.h"
#include "mediapipe/calculators/core/clip_vector_size_calculator.pb.h"
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h"
#include "mediapipe/calculators/tflite/ssd_anchors_calculator.pb.h"
#include "mediapipe/calculators/util/detections_to_rects_calculator.pb.h"
#include "mediapipe/calculators/util/non_max_suppression_calculator.pb.h"
#include "mediapipe/calculators/util/rect_transformation_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/subgraph.h"
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace pose_detector {
using ::mediapipe::NormalizedRect;
using ::mediapipe::Tensor;
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::vision::pose_detector::proto::
PoseDetectorGraphOptions;
namespace {
constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kAnchorsTag[] = "ANCHORS";
constexpr char kDetectionsTag[] = "DETECTIONS";
constexpr char kNormRectsTag[] = "NORM_RECTS";
constexpr char kPixelDetectionsTag[] = "PIXEL_DETECTIONS";
constexpr char kPoseRectsTag[] = "POSE_RECTS";
constexpr char kExpandedPoseRectsTag[] = "EXPANDED_POSE_RECTS";
constexpr char kMatrixTag[] = "MATRIX";
constexpr char kProjectionMatrixTag[] = "PROJECTION_MATRIX";
struct PoseDetectionOuts {
Source<std::vector<Detection>> pose_detections;
Source<std::vector<NormalizedRect>> pose_rects;
Source<std::vector<NormalizedRect>> expanded_pose_rects;
Source<Image> image;
};
// TODO: Configure detection related calculators in pose
// detector with model metadata.
void ConfigureSsdAnchorsCalculator(
mediapipe::SsdAnchorsCalculatorOptions* options) {
// Derived from
// mediapipe/modules/pose_detection/pose_detection_gpu.pbtxt
options->set_num_layers(5);
options->set_min_scale(0.1484375);
options->set_max_scale(0.75);
options->set_input_size_height(224);
options->set_input_size_width(224);
options->set_anchor_offset_x(0.5);
options->set_anchor_offset_y(0.5);
options->add_strides(8);
options->add_strides(16);
options->add_strides(32);
options->add_strides(32);
options->add_strides(32);
options->add_aspect_ratios(1.0);
options->set_fixed_anchor_size(true);
}
// TODO: Configure detection related calculators in pose
// detector with model metadata.
void ConfigureTensorsToDetectionsCalculator(
const PoseDetectorGraphOptions& tasks_options,
mediapipe::TensorsToDetectionsCalculatorOptions* options) {
// Derived from
// mediapipe/modules/pose_detection/pose_detection_gpu.pbtxt
options->set_num_classes(1);
options->set_num_boxes(2254);
options->set_num_coords(12);
options->set_box_coord_offset(0);
options->set_keypoint_coord_offset(4);
options->set_num_keypoints(4);
options->set_num_values_per_keypoint(2);
options->set_sigmoid_score(true);
options->set_score_clipping_thresh(100.0);
options->set_reverse_output_order(true);
options->set_min_score_thresh(tasks_options.min_detection_confidence());
options->set_x_scale(224.0);
options->set_y_scale(224.0);
options->set_w_scale(224.0);
options->set_h_scale(224.0);
}
void ConfigureNonMaxSuppressionCalculator(
const PoseDetectorGraphOptions& tasks_options,
mediapipe::NonMaxSuppressionCalculatorOptions* options) {
options->set_min_suppression_threshold(
tasks_options.min_suppression_threshold());
options->set_overlap_type(
mediapipe::NonMaxSuppressionCalculatorOptions::INTERSECTION_OVER_UNION);
options->set_algorithm(
mediapipe::NonMaxSuppressionCalculatorOptions::WEIGHTED);
}
// TODO: Configure detection related calculators in pose
// detector with model metadata.
void ConfigureDetectionsToRectsCalculator(
mediapipe::DetectionsToRectsCalculatorOptions* options) {
options->set_rotation_vector_start_keypoint_index(0);
options->set_rotation_vector_end_keypoint_index(2);
options->set_rotation_vector_target_angle(90);
options->set_output_zero_rect_for_empty_detections(true);
}
// TODO: Configure detection related calculators in pose
// detector with model metadata.
void ConfigureRectTransformationCalculator(
mediapipe::RectTransformationCalculatorOptions* options) {
options->set_scale_x(2.6);
options->set_scale_y(2.6);
options->set_shift_y(-0.5);
options->set_square_long(true);
}
} // namespace
// A "mediapipe.tasks.vision.pose_detector.PoseDetectorGraph" performs pose
// detection.
//
// Inputs:
// IMAGE - Image
// Image to perform detection on.
// NORM_RECT - NormalizedRect @Optional
// Describes image rotation and region of image to perform detection on. If
// not provided, whole image is used for pose detection.
//
// Outputs:
// DETECTIONS - std::vector<Detection>
// Detected pose with maximum `num_poses` specified in options.
// POSE_RECTS - std::vector<NormalizedRect>
// Detected pose bounding boxes in normalized coordinates.
// EXPANDED_POSE_RECTS - std::vector<NormalizedRect>
// Expanded pose bounding boxes in normalized coordinates so that bounding
// boxes likely contain the whole pose. This is usually used as RoI for pose
// landmarks detection to run on.
// IMAGE - Image
// The input image that the pose detector runs on and has the pixel data
// stored on the target storage (CPU vs GPU).
// All returned coordinates are in the unrotated and uncropped input image
// coordinate system.
//
// Example:
// node {
// calculator: "mediapipe.tasks.vision.pose_detector.PoseDetectorGraph"
// input_stream: "IMAGE:image"
// input_stream: "NORM_RECT:norm_rect"
// output_stream: "DETECTIONS:palm_detections"
// output_stream: "POSE_RECTS:pose_rects"
// output_stream: "EXPANDED_POSE_RECTS:expanded_pose_rects"
// output_stream: "IMAGE:image_out"
// options {
// [mediapipe.tasks.vision.pose_detector.proto.PoseDetectorGraphOptions.ext]
// {
// base_options {
// model_asset {
// file_name: "pose_detection.tflite"
// }
// }
// min_detection_confidence: 0.5
// num_poses: 2
// }
// }
// }
class PoseDetectorGraph : public core::ModelTaskGraph {
public:
absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override {
ASSIGN_OR_RETURN(const auto* model_resources,
CreateModelResources<PoseDetectorGraphOptions>(sc));
Graph graph;
ASSIGN_OR_RETURN(auto outs,
BuildPoseDetectionSubgraph(
sc->Options<PoseDetectorGraphOptions>(),
*model_resources, graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>(kNormRectTag)], graph));
outs.pose_detections >>
graph.Out(kDetectionsTag).Cast<std::vector<Detection>>();
outs.pose_rects >>
graph.Out(kPoseRectsTag).Cast<std::vector<NormalizedRect>>();
outs.expanded_pose_rects >>
graph.Out(kExpandedPoseRectsTag).Cast<std::vector<NormalizedRect>>();
outs.image >> graph.Out(kImageTag).Cast<Image>();
return graph.GetConfig();
}
private:
absl::StatusOr<PoseDetectionOuts> BuildPoseDetectionSubgraph(
const PoseDetectorGraphOptions& subgraph_options,
const core::ModelResources& model_resources, Source<Image> image_in,
Source<NormalizedRect> norm_rect_in, Graph& graph) {
// Image preprocessing subgraph to convert image to tensor for the tflite
// model.
auto& preprocessing = graph.AddNode(
"mediapipe.tasks.components.processors.ImagePreprocessingGraph");
bool use_gpu =
components::processors::DetermineImagePreprocessingGpuBackend(
subgraph_options.base_options().acceleration());
MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph(
model_resources, use_gpu,
&preprocessing.GetOptions<
components::processors::proto::ImagePreprocessingGraphOptions>()));
auto& image_to_tensor_options =
*preprocessing
.GetOptions<components::processors::proto::
ImagePreprocessingGraphOptions>()
.mutable_image_to_tensor_options();
image_to_tensor_options.set_keep_aspect_ratio(true);
image_to_tensor_options.set_border_mode(
mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO);
image_in >> preprocessing.In(kImageTag);
norm_rect_in >> preprocessing.In(kNormRectTag);
auto preprocessed_tensors = preprocessing.Out(kTensorsTag);
auto matrix = preprocessing.Out(kMatrixTag);
auto image_size = preprocessing.Out(kImageSizeTag);
// Pose detection model inference.
auto& inference = AddInference(
model_resources, subgraph_options.base_options().acceleration(), graph);
preprocessed_tensors >> inference.In(kTensorsTag);
auto model_output_tensors =
inference.Out(kTensorsTag).Cast<std::vector<Tensor>>();
// Generates a single side packet containing a vector of SSD anchors.
auto& ssd_anchor = graph.AddNode("SsdAnchorsCalculator");
ConfigureSsdAnchorsCalculator(
&ssd_anchor.GetOptions<mediapipe::SsdAnchorsCalculatorOptions>());
auto anchors = ssd_anchor.SideOut("");
// Converts output tensors to Detections.
auto& tensors_to_detections =
graph.AddNode("TensorsToDetectionsCalculator");
ConfigureTensorsToDetectionsCalculator(
subgraph_options,
&tensors_to_detections
.GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>());
model_output_tensors >> tensors_to_detections.In(kTensorsTag);
anchors >> tensors_to_detections.SideIn(kAnchorsTag);
auto detections = tensors_to_detections.Out(kDetectionsTag);
// Non-maximum suppression removes redundant pose detections.
auto& non_maximum_suppression =
graph.AddNode("NonMaxSuppressionCalculator");
ConfigureNonMaxSuppressionCalculator(
subgraph_options,
&non_maximum_suppression
.GetOptions<mediapipe::NonMaxSuppressionCalculatorOptions>());
detections >> non_maximum_suppression.In("");
auto nms_detections = non_maximum_suppression.Out("");
// Projects detections back into the input image coordinates system.
auto& detection_projection = graph.AddNode("DetectionProjectionCalculator");
nms_detections >> detection_projection.In(kDetectionsTag);
matrix >> detection_projection.In(kProjectionMatrixTag);
Source<std::vector<Detection>> pose_detections =
detection_projection.Out(kDetectionsTag).Cast<std::vector<Detection>>();
if (subgraph_options.has_num_poses()) {
// Clip pose detections to the maximum number of poses.
auto& clip_detection_vector_size =
graph.AddNode("ClipDetectionVectorSizeCalculator");
clip_detection_vector_size
.GetOptions<mediapipe::ClipVectorSizeCalculatorOptions>()
.set_max_vec_size(subgraph_options.num_poses());
pose_detections >> clip_detection_vector_size.In("");
pose_detections =
clip_detection_vector_size.Out("").Cast<std::vector<Detection>>();
}
// Converts results of pose detection into a rectangle (normalized by image
// size) that encloses the pose and is rotated such that the line connecting
// the two rotation keypoints of the detection is aligned with the X-axis of
// the rectangle.
auto& detections_to_rects = graph.AddNode("DetectionsToRectsCalculator");
ConfigureDetectionsToRectsCalculator(
&detections_to_rects
.GetOptions<mediapipe::DetectionsToRectsCalculatorOptions>());
image_size >> detections_to_rects.In(kImageSizeTag);
pose_detections >> detections_to_rects.In(kDetectionsTag);
auto pose_rects = detections_to_rects.Out(kNormRectsTag)
.Cast<std::vector<NormalizedRect>>();
// Expands and shifts the rectangle that contains the pose so that it's
// likely to cover the entire pose.
auto& rect_transformation = graph.AddNode("RectTransformationCalculator");
ConfigureRectTransformationCalculator(
&rect_transformation
.GetOptions<mediapipe::RectTransformationCalculatorOptions>());
pose_rects >> rect_transformation.In(kNormRectsTag);
image_size >> rect_transformation.In(kImageSizeTag);
auto expanded_pose_rects =
rect_transformation.Out("").Cast<std::vector<NormalizedRect>>();
// Calculator to convert relative detection bounding boxes to pixel
// detection bounding boxes.
auto& detection_transformation =
graph.AddNode("DetectionTransformationCalculator");
detection_projection.Out(kDetectionsTag) >>
detection_transformation.In(kDetectionsTag);
preprocessing.Out(kImageSizeTag) >>
detection_transformation.In(kImageSizeTag);
auto pose_pixel_detections =
detection_transformation.Out(kPixelDetectionsTag)
.Cast<std::vector<Detection>>();
return PoseDetectionOuts{
/* pose_detections= */ pose_pixel_detections,
/* pose_rects= */ pose_rects,
/* expanded_pose_rects= */ expanded_pose_rects,
/* image= */ preprocessing.Out(kImageTag).Cast<Image>()};
}
};
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::vision::pose_detector::PoseDetectorGraph);
} // namespace pose_detector
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,165 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/flags/flag.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace pose_detector {
namespace {
using ::file::Defaults;
using ::file::GetTextProto;
using ::mediapipe::NormalizedRect;
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::core::TaskRunner;
using ::mediapipe::tasks::vision::pose_detector::proto::
PoseDetectorGraphOptions;
using ::testing::EqualsProto;
using ::testing::Pointwise;
using ::testing::TestParamInfo;
using ::testing::TestWithParam;
using ::testing::Values;
using ::testing::proto::Approximately;
using ::testing::proto::Partially;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kPoseDetectionModel[] = "pose_detection.tflite";
constexpr char kPortraitImage[] = "pose.jpg";
constexpr char kPoseExpectedDetection[] = "pose_expected_detection.pbtxt";
constexpr char kImageTag[] = "IMAGE";
constexpr char kImageName[] = "image";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kNormRectName[] = "norm_rect";
constexpr char kDetectionsTag[] = "DETECTIONS";
constexpr char kDetectionsName[] = "detections";
constexpr float kPoseDetectionMaxDiff = 0.01;
// Helper function to create a TaskRunner.
absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner(
absl::string_view model_name) {
Graph graph;
auto& pose_detector_graph =
graph.AddNode("mediapipe.tasks.vision.pose_detector.PoseDetectorGraph");
auto options = std::make_unique<PoseDetectorGraphOptions>();
options->mutable_base_options()->mutable_model_asset()->set_file_name(
JoinPath("./", kTestDataDirectory, model_name));
options->set_min_detection_confidence(0.6);
options->set_min_suppression_threshold(0.3);
pose_detector_graph.GetOptions<PoseDetectorGraphOptions>().Swap(
options.get());
graph[Input<Image>(kImageTag)].SetName(kImageName) >>
pose_detector_graph.In(kImageTag);
graph[Input<NormalizedRect>(kNormRectTag)].SetName(kNormRectName) >>
pose_detector_graph.In(kNormRectTag);
pose_detector_graph.Out(kDetectionsTag).SetName(kDetectionsName) >>
graph[Output<std::vector<Detection>>(kDetectionsTag)];
return TaskRunner::Create(
graph.GetConfig(), std::make_unique<core::MediaPipeBuiltinOpResolver>());
}
Detection GetExpectedPoseDetectionResult(absl::string_view file_name) {
Detection detection;
CHECK_OK(GetTextProto(file::JoinPath("./", kTestDataDirectory, file_name),
&detection, Defaults()))
<< "Expected pose detection result does not exist.";
return detection;
}
struct TestParams {
// The name of this test, for convenience when displaying test results.
std::string test_name;
// The filename of pose landmark detection model.
std::string pose_detection_model_name;
// The filename of test image.
std::string test_image_name;
// Expected pose detection results.
std::vector<Detection> expected_result;
};
class PoseDetectorGraphTest : public testing::TestWithParam<TestParams> {};
TEST_P(PoseDetectorGraphTest, Succeed) {
MP_ASSERT_OK_AND_ASSIGN(
Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
GetParam().test_image_name)));
NormalizedRect input_norm_rect;
input_norm_rect.set_x_center(0.5);
input_norm_rect.set_y_center(0.5);
input_norm_rect.set_width(1.0);
input_norm_rect.set_height(1.0);
MP_ASSERT_OK_AND_ASSIGN(
auto task_runner, CreateTaskRunner(GetParam().pose_detection_model_name));
auto output_packets = task_runner->Process(
{{kImageName, MakePacket<Image>(std::move(image))},
{kNormRectName,
MakePacket<NormalizedRect>(std::move(input_norm_rect))}});
MP_ASSERT_OK(output_packets);
const std::vector<Detection>& pose_detections =
(*output_packets)[kDetectionsName].Get<std::vector<Detection>>();
EXPECT_THAT(pose_detections, Pointwise(Approximately(Partially(EqualsProto()),
kPoseDetectionMaxDiff),
GetParam().expected_result));
}
INSTANTIATE_TEST_SUITE_P(
PoseDetectorGraphTest, PoseDetectorGraphTest,
Values(TestParams{.test_name = "DetectPose",
.pose_detection_model_name = kPoseDetectionModel,
.test_image_name = kPortraitImage,
.expected_result = {GetExpectedPoseDetectionResult(
kPoseExpectedDetection)}}),
[](const TestParamInfo<PoseDetectorGraphTest::ParamType>& info) {
return info.param.test_name;
});
} // namespace
} // namespace pose_detector
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,31 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
package(default_visibility = [
"//mediapipe/tasks:internal",
])
licenses(["notice"])
mediapipe_proto_library(
name = "pose_detector_graph_options_proto",
srcs = ["pose_detector_graph_options.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto",
],
)

View File

@ -0,0 +1,45 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package mediapipe.tasks.vision.pose_detector.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/calculator_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
option java_package = "com.google.mediapipe.tasks.vision.posedetector.proto";
option java_outer_classname = "PoseDetectorGraphOptionsProto";
message PoseDetectorGraphOptions {
extend mediapipe.CalculatorOptions {
optional PoseDetectorGraphOptions ext = 514774813;
}
// Base options for configuring Task library, such as specifying the TfLite
// model file with metadata, accelerator options, etc.
optional core.proto.BaseOptions base_options = 1;
// Minimum confidence value ([0.0, 1.0]) for a detection to be considered a
// successful pose detection.
optional float min_detection_confidence = 2 [default = 0.5];
// IoU threshold ([0.0, 1.0]) above which non-maximum suppression considers
// detections to be duplicates.
optional float min_suppression_threshold = 3 [default = 0.5];
// Maximum number of poses to detect in the image.
optional int32 num_poses = 4;
}

View File

@ -16,6 +16,7 @@ package com.google.mediapipe.tasks.core;
import android.content.Context; import android.content.Context;
import android.util.Log; import android.util.Log;
import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig;
import com.google.mediapipe.framework.AndroidAssetUtil; import com.google.mediapipe.framework.AndroidAssetUtil;
import com.google.mediapipe.framework.AndroidPacketCreator; import com.google.mediapipe.framework.AndroidPacketCreator;
import com.google.mediapipe.framework.Graph; import com.google.mediapipe.framework.Graph;
@ -201,6 +202,10 @@ public class TaskRunner implements AutoCloseable {
} }
} }
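  /** Returns the {@link CalculatorGraphConfig} of the underlying MediaPipe graph. */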
public CalculatorGraphConfig getCalculatorGraphConfig() {
return graph.getCalculatorGraphConfig();
}
private synchronized void addPackets(Map<String, Packet> inputs, long inputTimestamp) { private synchronized void addPackets(Map<String, Packet> inputs, long inputTimestamp) {
if (!graphStarted.get()) { if (!graphStarted.get()) {
reportError( reportError(

View File

@ -41,6 +41,7 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",

View File

@ -197,6 +197,7 @@ android_library(
"//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",

View File

@ -17,6 +17,7 @@ package com.google.mediapipe.tasks.vision.imagesegmenter;
import android.content.Context; import android.content.Context;
import com.google.auto.value.AutoValue; import com.google.auto.value.AutoValue;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig;
import com.google.mediapipe.framework.AndroidPacketGetter; import com.google.mediapipe.framework.AndroidPacketGetter;
import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet; import com.google.mediapipe.framework.Packet;
@ -24,6 +25,7 @@ import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.ByteBufferImageBuilder; import com.google.mediapipe.framework.image.ByteBufferImageBuilder;
import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.TensorsToSegmentationCalculatorOptionsProto;
import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.ErrorListener; import com.google.mediapipe.tasks.core.ErrorListener;
import com.google.mediapipe.tasks.core.OutputHandler; import com.google.mediapipe.tasks.core.OutputHandler;
@ -88,8 +90,10 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
private static final int SEGMENTATION_OUT_STREAM_INDEX = 2; private static final int SEGMENTATION_OUT_STREAM_INDEX = 2;
private static final String TASK_GRAPH_NAME = private static final String TASK_GRAPH_NAME =
"mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph"; "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
"mediapipe.tasks.TensorsToSegmentationCalculator";
private boolean hasResultListener = false; private boolean hasResultListener = false;
private List<String> labels = new ArrayList<>();
/** /**
* Creates an {@link ImageSegmenter} instance from an {@link ImageSegmenterOptions}. * Creates an {@link ImageSegmenter} instance from an {@link ImageSegmenterOptions}.
@ -190,6 +194,41 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
TaskRunner taskRunner, RunningMode runningMode, boolean hasResultListener) { TaskRunner taskRunner, RunningMode runningMode, boolean hasResultListener) {
super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME); super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
this.hasResultListener = hasResultListener; this.hasResultListener = hasResultListener;
populateLabels();
}
/**
* Populates the label map from the TensorsToSegmentationCalculator options into the labels field.
*
* @throws MediaPipeException if an error occurs while locating the TensorsToSegmentationCalculator.
*/
private void populateLabels() {
CalculatorGraphConfig graphConfig = this.runner.getCalculatorGraphConfig();
boolean foundTensorsToSegmentation = false;
for (CalculatorGraphConfig.Node node : graphConfig.getNodeList()) {
if (node.getName().contains(TENSORS_TO_SEGMENTATION_CALCULATOR_NAME)) {
if (foundTensorsToSegmentation) {
throw new MediaPipeException(
MediaPipeException.StatusCode.INTERNAL.ordinal(),
"The graph has more than one mediapipe.tasks.TensorsToSegmentationCalculator.");
}
foundTensorsToSegmentation = true;
TensorsToSegmentationCalculatorOptionsProto.TensorsToSegmentationCalculatorOptions options =
node.getOptions()
.getExtension(
TensorsToSegmentationCalculatorOptionsProto
.TensorsToSegmentationCalculatorOptions.ext);
for (int i = 0; i < options.getLabelItemsMap().size(); i++) {
Long labelKey = Long.valueOf(i);
if (!options.getLabelItemsMap().containsKey(labelKey)) {
throw new MediaPipeException(
MediaPipeException.StatusCode.INTERNAL.ordinal(),
"The lablemap have no expected key: " + labelKey);
}
labels.add(options.getLabelItemsMap().get(labelKey).getName());
}
}
}
} }
/** /**
@ -473,6 +512,17 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
sendLiveStreamData(image, imageProcessingOptions, timestampMs); sendLiveStreamData(image, imageProcessingOptions, timestampMs);
} }
/**
* Gets the list of category labels that the ImageSegmenter can recognize. For CATEGORY_MASK type,
* the index in the category mask corresponds to the category in the label list. For
* CONFIDENCE_MASK type, the output mask at each index corresponds to the category in the label
* list.
*
* <p>If no labelmap is provided in the model file, an empty label list is returned.
*/
List<String> getLabels() {
return labels;
}
/** Options for setting up an {@link ImageSegmenter}. */ /** Options for setting up an {@link ImageSegmenter}. */
@AutoValue @AutoValue
public abstract static class ImageSegmenterOptions extends TaskOptions { public abstract static class ImageSegmenterOptions extends TaskOptions {

View File

@ -7,6 +7,7 @@ VERS_1.0 {
Java_com_google_mediapipe_framework_Graph_nativeAddPacketToInputStream; Java_com_google_mediapipe_framework_Graph_nativeAddPacketToInputStream;
Java_com_google_mediapipe_framework_Graph_nativeCloseAllPacketSources; Java_com_google_mediapipe_framework_Graph_nativeCloseAllPacketSources;
Java_com_google_mediapipe_framework_Graph_nativeCreateGraph; Java_com_google_mediapipe_framework_Graph_nativeCreateGraph;
Java_com_google_mediapipe_framework_Graph_nativeGetCalculatorGraphConfig;
Java_com_google_mediapipe_framework_Graph_nativeLoadBinaryGraph*; Java_com_google_mediapipe_framework_Graph_nativeLoadBinaryGraph*;
Java_com_google_mediapipe_framework_Graph_nativeMovePacketToInputStream; Java_com_google_mediapipe_framework_Graph_nativeMovePacketToInputStream;
Java_com_google_mediapipe_framework_Graph_nativeReleaseGraph; Java_com_google_mediapipe_framework_Graph_nativeReleaseGraph;

View File

@ -34,6 +34,7 @@ import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenter.ImageSegm
import java.io.InputStream; import java.io.InputStream;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.FloatBuffer; import java.nio.FloatBuffer;
import java.util.Arrays;
import java.util.List; import java.util.List;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
@ -135,6 +136,45 @@ public class ImageSegmenterTest {
// MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); // MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
// verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); // verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
// } // }
@Test
public void getLabels_success() throws Exception {
final List<String> expectedLabels =
Arrays.asList(
"background",
"aeroplane",
"bicycle",
"bird",
"boat",
"bottle",
"bus",
"car",
"cat",
"chair",
"cow",
"dining table",
"dog",
"horse",
"motorbike",
"person",
"potted plant",
"sheep",
"sofa",
"train",
"tv");
ImageSegmenterOptions options =
ImageSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
.build();
ImageSegmenter imageSegmenter =
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
List<String> actualLabels = imageSegmenter.getLabels();
assertThat(actualLabels.size()).isEqualTo(expectedLabels.size());
for (int i = 0; i < actualLabels.size(); i++) {
assertThat(actualLabels.get(i)).isEqualTo(expectedLabels.get(i));
}
}
} }
@RunWith(AndroidJUnit4.class) @RunWith(AndroidJUnit4.class)

View File

@ -7,7 +7,9 @@ package(
licenses = ["notice"], # Apache 2.0 licenses = ["notice"], # Apache 2.0
) )
exports_files(["metadata_schema.fbs"]) exports_files(glob([
"*.fbs",
]))
# Generic schema for model metadata. # Generic schema for model metadata.
flatbuffer_cc_library( flatbuffer_cc_library(
@ -24,3 +26,13 @@ flatbuffer_py_library(
name = "metadata_schema_py", name = "metadata_schema_py",
srcs = ["metadata_schema.fbs"], srcs = ["metadata_schema.fbs"],
) )
flatbuffer_cc_library(
name = "image_segmenter_metadata_schema_cc",
srcs = ["image_segmenter_metadata_schema.fbs"],
)
flatbuffer_py_library(
name = "image_segmenter_metadata_schema_py",
srcs = ["image_segmenter_metadata_schema.fbs"],
)

View File

@ -0,0 +1,59 @@
// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
namespace mediapipe.tasks;
// Image segmenter metadata contains information specific for the image
// segmentation task. The metadata can be added in
// SubGraphMetadata.custom_metadata [1] in model metadata.
// [1]: https://github.com/google/mediapipe/blob/46b5c4012d2ef76c9d92bb0d88a6b107aee83814/mediapipe/tasks/metadata/metadata_schema.fbs#L685
// ImageSegmenterOptions.min_parser_version indicates the minimum necessary
// image segmenter metadata parser version to fully understand all fields in a
// given metadata flatbuffer. This min_parser_version is specific for the
// image segmenter metadata defined in this schema file.
//
// New fields and types will have associated comments with the schema version
// for which they were added.
//
// Schema Semantic version: 1.0.0
// This indicates the flatbuffer compatibility. The number will bump up when a
// breaking change is applied to the schema, such as removing fields or adding new
// fields to the middle of a table.
file_identifier "V001";
// History:
// 1.0.0 - Initial version.
// Supported activation functions.
enum Activation: byte {
NONE = 0,
SIGMOID = 1,
SOFTMAX = 2
}
table ImageSegmenterOptions {
// The activation function of the output layer in the image segmenter.
activation: Activation;
// The minimum necessary image segmenter metadata parser version to fully
// understand all fields in a given metadata flatbuffer. This field is
// automatically populated by the MetadataPopulator when the metadata is
// populated into a TFLite model. This min_parser_version is specific for the
// image segmenter metadata defined in this schema file.
min_parser_version:string;
}
root_type ImageSegmenterOptions;
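
For reference, a minimal Python sketch of how this table ends up in `SubGraphMetadata.custom_metadata`, mirroring the metadata writer added later in this change (the generated module name and packing calls follow that writer; the activation value 1 corresponds to SIGMOID above):

```python
# Sketch only: pack an ImageSegmenterOptions table into the bytes stored in
# SubGraphMetadata.custom_metadata. Mirrors ImageSegmenterOptionsMd below.
import flatbuffers
from mediapipe.tasks.metadata import image_segmenter_metadata_schema_py_generated as segmenter_fb

options = segmenter_fb.ImageSegmenterOptionsT()
options.activation = 1  # SIGMOID, per the Activation enum in this schema

builder = flatbuffers.Builder(0)
# "V001" is the file_identifier declared above.
builder.Finish(options.Pack(builder), b"V001")
custom_metadata_payload = builder.Output()
```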

View File

@ -233,7 +233,7 @@ table ImageProperties {
// //
// <Codegen usage>: // <Codegen usage>:
// Input image tensors: NA. // Input image tensors: NA.
// Output image tensors: parses the values into a data stucture that represents // Output image tensors: parses the values into a data structure that represents
// bounding boxes. For example, in the generated wrapper for Android, it returns // bounding boxes. For example, in the generated wrapper for Android, it returns
// the output as android.graphics.Rect objects. // the output as android.graphics.Rect objects.
enum BoundingBoxType : byte { enum BoundingBoxType : byte {
@ -389,7 +389,7 @@ table NormalizationOptions{
// mean and std are normalization parameters. Tensor values are normalized // mean and std are normalization parameters. Tensor values are normalized
// on a per-channel basis, by the formula // on a per-channel basis, by the formula
// (x - mean) / std. // (x - mean) / std.
// If there is only one value in mean or std, we'll propogate the value to // If there is only one value in mean or std, we'll propagate the value to
// all channels. // all channels.
// //
// Quantized models share the same normalization parameters as their // Quantized models share the same normalization parameters as their
@ -526,7 +526,7 @@ table Stats {
// Max and min are not currently used in tflite.support codegen. They mainly // Max and min are not currently used in tflite.support codegen. They mainly
// serve as references for users to better understand the model. They can also // serve as references for users to better understand the model. They can also
// be used to validate model pre/post processing results. // be used to validate model pre/post processing results.
// If there is only one value in max or min, we'll propogate the value to // If there is only one value in max or min, we'll propagate the value to
// all channels. // all channels.
// Per-channel maximum value of the tensor. // Per-channel maximum value of the tensor.
@ -542,7 +542,7 @@ table Stats {
// has four outputs: classes, scores, bounding boxes, and number of detections. // has four outputs: classes, scores, bounding boxes, and number of detections.
// If the four outputs are bundled together using TensorGroup (for example, // If the four outputs are bundled together using TensorGroup (for example,
// named as "detection result"), the codegen tool will generate the class, // named as "detection result"), the codegen tool will generate the class,
// `DetectionResult`, which contains the class, score, and bouding box. And the // `DetectionResult`, which contains the class, score, and bounding box. And the
// outputs of the model will be converted to a list of `DetectionResults` and // outputs of the model will be converted to a list of `DetectionResults` and
// the number of detection. Note that the number of detection is a single // the number of detection. Note that the number of detection is a single
// number, therefore is inappropriate for the list of `DetectionResult`. // number, therefore is inappropriate for the list of `DetectionResult`.
@ -624,7 +624,7 @@ table SubGraphMetadata {
// A description explains details about what the subgraph does. // A description explains details about what the subgraph does.
description:string; description:string;
// Metadata of all input tensors used in this subgraph. It matches extactly // Metadata of all input tensors used in this subgraph. It matches exactly
// the input tensors specified by `SubGraph.inputs` in the TFLite // the input tensors specified by `SubGraph.inputs` in the TFLite
// schema.fbs file[2]. The number of `TensorMetadata` in the array should // schema.fbs file[2]. The number of `TensorMetadata` in the array should
// equal to the number of indices in `SubGraph.inputs`. // equal to the number of indices in `SubGraph.inputs`.
@ -634,7 +634,7 @@ table SubGraphMetadata {
// Determines how to process the inputs. // Determines how to process the inputs.
input_tensor_metadata:[TensorMetadata]; input_tensor_metadata:[TensorMetadata];
// Metadata of all output tensors used in this subgraph. It matches extactly // Metadata of all output tensors used in this subgraph. It matches exactly
// the output tensors specified by `SubGraph.outputs` in the TFLite // the output tensors specified by `SubGraph.outputs` in the TFLite
// schema.fbs file[2]. The number of `TensorMetadata` in the array should // schema.fbs file[2]. The number of `TensorMetadata` in the array should
// equal to the number of indices in `SubGraph.outputs`. // equal to the number of indices in `SubGraph.outputs`.
@ -724,7 +724,7 @@ table ModelMetadata {
// number among the versions of all the fields populated and the smallest // number among the versions of all the fields populated and the smallest
// compatible version indicated by the file identifier. // compatible version indicated by the file identifier.
// //
// This field is automaticaly populated by the MetadataPopulator when // This field is automatically populated by the MetadataPopulator when
// the metadata is populated into a TFLite model. // the metadata is populated into a TFLite model.
min_parser_version:string; min_parser_version:string;
} }

View File

@ -17,10 +17,13 @@
import copy import copy
import inspect import inspect
import io import io
import json
import logging
import os import os
import shutil import shutil
import sys import sys
import tempfile import tempfile
from typing import Dict, Optional
import warnings import warnings
import zipfile import zipfile
@ -789,13 +792,43 @@ class MetadataDisplayer(object):
return [] return []
def _get_custom_metadata(metadata_buffer: bytes, name: str):
"""Gets the custom metadata in metadata_buffer based on the name.
Args:
metadata_buffer: valid metadata buffer in bytes.
name: custom metadata name.
Returns:
Index of custom metadata, custom metadata flatbuffer. Returns (None, None)
if the custom metadata is not found.
"""
model_metadata = _metadata_fb.ModelMetadata.GetRootAs(metadata_buffer)
subgraph = model_metadata.SubgraphMetadata(0)
if subgraph is None or subgraph.CustomMetadataIsNone():
return None, None
for i in range(subgraph.CustomMetadataLength()):
custom_metadata = subgraph.CustomMetadata(i)
if custom_metadata.Name().decode("utf-8") == name:
return i, custom_metadata.DataAsNumpy().tobytes()
return None, None
# Create an individual method for getting the metadata json file, so that it can # Create an individual method for getting the metadata json file, so that it can
# be used as a standalone util. # be used as a standalone util.
def convert_to_json(metadata_buffer): def convert_to_json(
metadata_buffer, custom_metadata_schema: Optional[Dict[str, str]] = None
) -> str:
"""Converts the metadata into a json string. """Converts the metadata into a json string.
Args: Args:
metadata_buffer: valid metadata buffer in bytes. metadata_buffer: valid metadata buffer in bytes.
custom_metadata_schema: A dict of custom metadata schemas, in which the key is
the custom metadata name [1] and the value is the filepath that defines the
custom metadata schema. For instance, custom_metadata_schema =
{"SEGMENTER_METADATA": "metadata/vision_tasks_metadata_schema.fbs"}. [1]:
https://github.com/google/mediapipe/blob/46b5c4012d2ef76c9d92bb0d88a6b107aee83814/mediapipe/tasks/metadata/metadata_schema.fbs#L612
Returns: Returns:
Metadata in JSON format. Metadata in JSON format.
@ -803,7 +836,6 @@ def convert_to_json(metadata_buffer):
Raises: Raises:
ValueError: error occured when parsing the metadata schema file. ValueError: error occured when parsing the metadata schema file.
""" """
opt = _pywrap_flatbuffers.IDLOptions() opt = _pywrap_flatbuffers.IDLOptions()
opt.strict_json = True opt.strict_json = True
parser = _pywrap_flatbuffers.Parser(opt) parser = _pywrap_flatbuffers.Parser(opt)
@ -811,7 +843,35 @@ def convert_to_json(metadata_buffer):
metadata_schema_content = f.read() metadata_schema_content = f.read()
if not parser.parse(metadata_schema_content): if not parser.parse(metadata_schema_content):
raise ValueError("Cannot parse metadata schema. Reason: " + parser.error) raise ValueError("Cannot parse metadata schema. Reason: " + parser.error)
return _pywrap_flatbuffers.generate_text(parser, metadata_buffer) # Json content which may contain binary custom metadata.
raw_json_content = _pywrap_flatbuffers.generate_text(parser, metadata_buffer)
if not custom_metadata_schema:
return raw_json_content
json_data = json.loads(raw_json_content)
# Gets the custom metadata by name and parses the binary custom metadata into
# human-readable JSON content.
for name, schema_file in custom_metadata_schema.items():
idx, custom_metadata = _get_custom_metadata(metadata_buffer, name)
if not custom_metadata:
logging.info(
"No custom metadata with name %s in metadata flatbuffer.", name
)
continue
_assert_file_exist(schema_file)
with _open_file(schema_file, "rb") as f:
custom_metadata_schema_content = f.read()
if not parser.parse(custom_metadata_schema_content):
raise ValueError(
"Cannot parse custom metadata schema. Reason: " + parser.error
)
custom_metadata_json = _pywrap_flatbuffers.generate_text(
parser, custom_metadata
)
json_meta = json_data["subgraph_metadata"][0]["custom_metadata"][idx]
json_meta["name"] = name
json_meta["data"] = json.loads(custom_metadata_json)
return json.dumps(json_data, indent=2)
def _assert_file_exist(filename): def _assert_file_exist(filename):
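
A hedged usage sketch of the new custom-metadata path (assumes `metadata_buffer` was already extracted from a populated model, e.g. via `get_metadata_buffer`, and that the segmenter schema file is available at the path shown):

```python
from mediapipe.tasks.python.metadata import metadata

# metadata_buffer: bytes extracted from a TFLite model whose subgraph carries
# a custom metadata entry named "SEGMENTER_METADATA".
json_str = metadata.convert_to_json(
    metadata_buffer,
    custom_metadata_schema={
        "SEGMENTER_METADATA":
            "mediapipe/tasks/metadata/image_segmenter_metadata_schema.fbs"
    },
)
print(json_str)  # the custom metadata is rendered as human-readable JSON
```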

View File

@ -50,6 +50,20 @@ py_library(
deps = [":metadata_writer"], deps = [":metadata_writer"],
) )
py_library(
name = "image_segmenter",
srcs = ["image_segmenter.py"],
data = ["//mediapipe/tasks/metadata:image_segmenter_metadata_schema.fbs"],
deps = [
":metadata_info",
":metadata_writer",
"//mediapipe/tasks/metadata:image_segmenter_metadata_schema_py",
"//mediapipe/tasks/metadata:metadata_schema_py",
"//mediapipe/tasks/python/metadata",
"@flatbuffers//:runtime_py",
],
)
py_library( py_library(
name = "object_detector", name = "object_detector",
srcs = ["object_detector.py"], srcs = ["object_detector.py"],

View File

@ -0,0 +1,161 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Writes metadata and label file to the image segmenter models."""
import enum
from typing import List, Optional
import flatbuffers
from mediapipe.tasks.metadata import image_segmenter_metadata_schema_py_generated as _segmenter_metadata_fb
from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb
from mediapipe.tasks.python.metadata import metadata
from mediapipe.tasks.python.metadata.metadata_writers import metadata_info
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
_MODEL_NAME = "ImageSegmenter"
_MODEL_DESCRIPTION = (
"Semantic image segmentation predicts whether each pixel "
"of an image is associated with a certain class."
)
# Metadata Schema file for image segmenter.
_FLATC_METADATA_SCHEMA_FILE = metadata.get_path_to_datafile(
"../../../metadata/image_segmenter_metadata_schema.fbs",
)
# Metadata name in custom metadata field. The metadata name is used to get
# image segmenter metadata from SubGraphMetadata.custom_metadata and
# shouldn't be changed.
_METADATA_NAME = "SEGMENTER_METADATA"
class Activation(enum.Enum):
NONE = 0
SIGMOID = 1
SOFTMAX = 2
# Create an individual method for getting the metadata json file, so that it can
# be used as a standalone util.
def convert_to_json(metadata_buffer: bytearray) -> str:
"""Converts the metadata into a json string.
Args:
metadata_buffer: valid metadata buffer in bytes.
Returns:
Metadata in JSON format.
Raises:
ValueError: error occurred when parsing the metadata schema file.
"""
return metadata.convert_to_json(
metadata_buffer,
custom_metadata_schema={_METADATA_NAME: _FLATC_METADATA_SCHEMA_FILE},
)
class ImageSegmenterOptionsMd(metadata_info.CustomMetadataMd):
"""Image segmenter options metadata."""
_METADATA_FILE_IDENTIFIER = b"V001"
def __init__(self, activation: Activation) -> None:
"""Creates an ImageSegmenterOptionsMd object.
Args:
activation: activation function of the output layer in the image
segmenter.
"""
self.activation = activation
super().__init__(name=_METADATA_NAME)
def create_metadata(self) -> _metadata_fb.CustomMetadataT:
"""Creates the image segmenter options metadata.
Returns:
A Flatbuffers Python object of the custom metadata including image
segmenter options metadata.
"""
segmenter_options = _segmenter_metadata_fb.ImageSegmenterOptionsT()
segmenter_options.activation = self.activation.value
# Get the image segmenter options flatbuffer.
b = flatbuffers.Builder(0)
b.Finish(segmenter_options.Pack(b), self._METADATA_FILE_IDENTIFIER)
segmenter_options_buf = b.Output()
# Add the image segmenter options flatbuffer in custom metadata.
custom_metadata = _metadata_fb.CustomMetadataT()
custom_metadata.name = self.name
custom_metadata.data = segmenter_options_buf
return custom_metadata
class MetadataWriter(metadata_writer.MetadataWriterBase):
"""MetadataWriter to write the metadata for image segmenter."""
@classmethod
def create(
cls,
model_buffer: bytearray,
input_norm_mean: List[float],
input_norm_std: List[float],
labels: Optional[metadata_writer.Labels] = None,
activation: Optional[Activation] = None,
) -> "MetadataWriter":
"""Creates MetadataWriter to write the metadata for image segmenter.
The parameters required in this method are mandatory when using MediaPipe
Tasks.
Example usage:
metadata_writer = image_segmenter.MetadataWriter.create(model_buffer, ...)
tflite_content, json_content = metadata_writer.populate()
When calling the `populate` function of this class, it returns the TFLite content
and JSON content. Note that only the output TFLite is used for deployment.
The output JSON content is used to interpret the metadata content.
Args:
model_buffer: A valid flatbuffer loaded from the TFLite model file.
input_norm_mean: the mean value used in the input tensor normalization
[1].
input_norm_std: the std value used in the input tensor normalization [1].
labels: an instance of Labels helper class used in the output category
tensor [2].
activation: activation function for the output layer.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L389
[2]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L116
Returns:
A MetadataWriter object.
"""
writer = metadata_writer.MetadataWriter(model_buffer)
writer.add_general_info(_MODEL_NAME, _MODEL_DESCRIPTION)
writer.add_image_input(input_norm_mean, input_norm_std)
writer.add_segmentation_output(labels=labels)
if activation is not None:
option_md = ImageSegmenterOptionsMd(activation)
writer.add_custom_metadata(option_md)
return cls(writer)
def populate(self) -> tuple[bytearray, str]:
model_buf, _ = super().populate()
metadata_buf = metadata.get_metadata_buffer(model_buf)
json_content = convert_to_json(metadata_buf)
return model_buf, json_content
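
A usage sketch mirroring the unit test added below (file paths are illustrative):

```python
from mediapipe.tasks.python.metadata.metadata_writers import image_segmenter
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer

with open("deeplabv3_without_metadata.tflite", "rb") as f:  # illustrative path
    model_buffer = bytearray(f.read())

writer = image_segmenter.MetadataWriter.create(
    model_buffer,
    [127.5],  # input_norm_mean
    [127.5],  # input_norm_std
    labels=metadata_writer.Labels().add_from_file("segmenter_labelmap.txt"),
    activation=image_segmenter.Activation.SOFTMAX,
)
tflite_with_metadata, metadata_json = writer.populate()
```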

View File

@ -1030,6 +1030,52 @@ class TensorGroupMd:
return group return group
class SegmentationMaskMd(TensorMd):
"""A container for the segmentation mask metadata information."""
# The output tensor is in the shape of [1, ImageHeight, ImageWidth, N], where
# N is the number of objects that the segmentation model can recognize. The
# output tensor is essentially a list of grayscale bitmaps, where each value
# is the probability of the corresponding pixel belonging to a certain object
# type. Therefore, the content dimension range of the output tensor is [1, 2].
_CONTENT_DIM_MIN = 1
_CONTENT_DIM_MAX = 2
def __init__(
self,
name: Optional[str] = None,
description: Optional[str] = None,
label_files: Optional[List[LabelFileMd]] = None,
):
self.name = name
self.description = description
associated_files = label_files or []
super().__init__(
name=name, description=description, associated_files=associated_files
)
def create_metadata(self) -> _metadata_fb.TensorMetadataT:
"""Creates the metadata for the segmentation masks tensor."""
masks_metadata = super().create_metadata()
# Create tensor content information.
content = _metadata_fb.ContentT()
content.contentProperties = _metadata_fb.ImagePropertiesT()
content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.GRAYSCALE
content.contentPropertiesType = (
_metadata_fb.ContentProperties.ImageProperties
)
# Add the content range. See
# https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L323-L385
dim_range = _metadata_fb.ValueRangeT()
dim_range.min = self._CONTENT_DIM_MIN
dim_range.max = self._CONTENT_DIM_MAX
content.range = dim_range
masks_metadata.content = content
return masks_metadata
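
A short sketch, following the metadata_info test added below, of how this container is used on its own:

```python
from mediapipe.tasks.python.metadata.metadata_writers import metadata_info

mask_md = metadata_info.SegmentationMaskMd(
    name="segmentation_masks",
    description="Masks over the target objects.",
)
# Produces a TensorMetadataT with GRAYSCALE ImageProperties and a [1, 2]
# content dimension range, as described in the class comment above.
tensor_metadata = mask_md.create_metadata()
```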
class CustomMetadataMd(abc.ABC): class CustomMetadataMd(abc.ABC):
"""An abstract class of a container for the custom metadata information.""" """An abstract class of a container for the custom metadata information."""

View File

@ -34,6 +34,10 @@ _INPUT_REGEX_TEXT_DESCRIPTION = ('Embedding vectors representing the input '
'text to be processed.') 'text to be processed.')
_OUTPUT_CLASSIFICATION_NAME = 'score' _OUTPUT_CLASSIFICATION_NAME = 'score'
_OUTPUT_CLASSIFICATION_DESCRIPTION = 'Score of the labels respectively.' _OUTPUT_CLASSIFICATION_DESCRIPTION = 'Score of the labels respectively.'
_OUTPUT_SEGMENTATION_MASKS_NAME = 'segmentation_masks'
_OUTPUT_SEGMENTATION_MASKS_DESCRIPTION = (
'Masks over the target objects with high accuracy.'
)
# Detection tensor result to be grouped together. # Detection tensor result to be grouped together.
_DETECTION_GROUP_NAME = 'detection_result' _DETECTION_GROUP_NAME = 'detection_result'
# File name to export score calibration parameters. # File name to export score calibration parameters.
@ -657,6 +661,32 @@ class MetadataWriter(object):
self._output_group_mds.append(group_md) self._output_group_mds.append(group_md)
return self return self
def add_segmentation_output(
self,
labels: Optional[Labels] = None,
name: str = _OUTPUT_SEGMENTATION_MASKS_NAME,
description: str = _OUTPUT_SEGMENTATION_MASKS_DESCRIPTION,
) -> 'MetadataWriter':
"""Adds a segmentation head metadata for segmentation output tensor.
Args:
labels: an instance of Labels helper class.
name: Metadata name of the tensor. Note that this is different from tensor
name in the flatbuffer.
description: human readable description of what the output is.
Returns:
The current Writer instance to allow chained operation.
"""
label_files = self._create_label_file_md(labels)
output_md = metadata_info.SegmentationMaskMd(
name=name,
description=description,
label_files=label_files,
)
self._output_mds.append(output_md)
return self
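
This is the lower-level call that the image segmenter metadata writer delegates to; a brief chaining sketch (normalization values and label path are illustrative):

```python
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer

writer = metadata_writer.MetadataWriter(model_buffer)  # model_buffer: bytearray of the TFLite model
writer.add_general_info("ImageSegmenter", "Semantic image segmentation model.")
writer.add_image_input([127.5], [127.5])  # input_norm_mean, input_norm_std
writer.add_segmentation_output(
    labels=metadata_writer.Labels().add_from_file("labels.txt")
)
```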
def add_feature_output(self, def add_feature_output(self,
name: Optional[str] = None, name: Optional[str] = None,
description: Optional[str] = None) -> 'MetadataWriter': description: Optional[str] = None) -> 'MetadataWriter':

View File

@ -91,3 +91,18 @@ py_test(
"//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/test:test_utils",
], ],
) )
py_test(
name = "image_segmenter_test",
srcs = ["image_segmenter_test.py"],
data = [
"//mediapipe/tasks/testdata/metadata:data_files",
"//mediapipe/tasks/testdata/metadata:model_files",
],
deps = [
"//mediapipe/tasks/python/metadata",
"//mediapipe/tasks/python/metadata/metadata_writers:image_segmenter",
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
"//mediapipe/tasks/python/test:test_utils",
],
)

View File

@ -0,0 +1,98 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for metadata_writer.image_segmenter."""
import os
from absl.testing import absltest
from mediapipe.tasks.python.metadata import metadata
from mediapipe.tasks.python.metadata.metadata_writers import image_segmenter
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
from mediapipe.tasks.python.test import test_utils
_TEST_DATA_DIR = "mediapipe/tasks/testdata/metadata"
_MODEL_FILE = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, "deeplabv3_without_metadata.tflite")
)
_LABEL_FILE_NAME = "labels.txt"
_LABEL_FILE = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, "segmenter_labelmap.txt")
)
_NORM_MEAN = 127.5
_NORM_STD = 127.5
_JSON_FILE = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, "deeplabv3.json")
)
_JSON_FILE_WITHOUT_LABELS = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, "deeplabv3_without_labels.json")
)
_JSON_FILE_WITH_ACTIVATION = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, "deeplabv3_with_activation.json")
)
class ImageSegmenterTest(absltest.TestCase):
def test_write_metadata(self):
with open(_MODEL_FILE, "rb") as f:
model_buffer = f.read()
writer = image_segmenter.MetadataWriter.create(
bytearray(model_buffer),
[_NORM_MEAN],
[_NORM_STD],
labels=metadata_writer.Labels().add_from_file(_LABEL_FILE),
)
tflite_content, metadata_json = writer.populate()
with open(_JSON_FILE, "r") as f:
expected_json = f.read().strip()
self.assertEqual(metadata_json, expected_json)
displayer = metadata.MetadataDisplayer.with_model_buffer(tflite_content)
label_file_buffer = displayer.get_associated_file_buffer(_LABEL_FILE_NAME)
with open(_LABEL_FILE, "rb") as f:
expected_labelfile_buffer = f.read()
self.assertEqual(label_file_buffer, expected_labelfile_buffer)
def test_write_metadata_without_labels(self):
with open(_MODEL_FILE, "rb") as f:
model_buffer = f.read()
writer = image_segmenter.MetadataWriter.create(
bytearray(model_buffer),
[_NORM_MEAN],
[_NORM_STD],
)
_, metadata_json = writer.populate()
with open(_JSON_FILE_WITHOUT_LABELS, "r") as f:
expected_json = f.read().strip()
self.assertEqual(metadata_json, expected_json)
def test_write_metadata_with_activation(self):
with open(_MODEL_FILE, "rb") as f:
model_buffer = f.read()
writer = image_segmenter.MetadataWriter.create(
bytearray(model_buffer),
[_NORM_MEAN],
[_NORM_STD],
activation=image_segmenter.Activation.SIGMOID,
)
_, metadata_json = writer.populate()
with open(_JSON_FILE_WITH_ACTIVATION, "r") as f:
expected_json = f.read().strip()
self.assertEqual(metadata_json, expected_json)
if __name__ == "__main__":
absltest.main()

View File

@ -455,6 +455,27 @@ class TensorGroupMdMdTest(absltest.TestCase):
self.assertEqual(metadata_json, expected_json) self.assertEqual(metadata_json, expected_json)
class SegmentationMaskMdTest(absltest.TestCase):
_NAME = "segmentation_masks"
_DESCRIPTION = "Masks over the target objects."
_EXPECTED_JSON = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, "segmentation_mask_meta.json")
)
def test_create_metadata_should_succeed(self):
segmentation_mask_md = metadata_info.SegmentationMaskMd(
name=self._NAME, description=self._DESCRIPTION
)
metadata = segmentation_mask_md.create_metadata()
metadata_json = _metadata.convert_to_json(
_create_dummy_model_metadata_with_tensor(metadata)
)
with open(self._EXPECTED_JSON, "r") as f:
expected_json = f.read()
self.assertEqual(metadata_json, expected_json)
def _create_dummy_model_metadata_with_tensor( def _create_dummy_model_metadata_with_tensor(
tensor_metadata: _metadata_fb.TensorMetadataT) -> bytes: tensor_metadata: _metadata_fb.TensorMetadataT) -> bytes:
# Create a dummy model using the tensor metadata. # Create a dummy model using the tensor metadata.

View File

@ -42,48 +42,62 @@ _RUNNING_MODE = running_mode_module.VisionTaskRunningMode
_MODEL_FILE = 'coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite' _MODEL_FILE = 'coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite'
_IMAGE_FILE = 'cats_and_dogs.jpg' _IMAGE_FILE = 'cats_and_dogs.jpg'
_EXPECTED_DETECTION_RESULT = _DetectionResult(detections=[ _EXPECTED_DETECTION_RESULT = _DetectionResult(
detections=[
_Detection( _Detection(
bounding_box=_BoundingBox( bounding_box=_BoundingBox(
origin_x=608, origin_y=161, width=381, height=439), origin_x=608, origin_y=161, width=381, height=439
),
categories=[ categories=[
_Category( _Category(
index=None, index=None,
score=0.69921875, score=0.69921875,
display_name=None, display_name=None,
category_name='cat') category_name='cat',
]), )
],
),
_Detection( _Detection(
bounding_box=_BoundingBox( bounding_box=_BoundingBox(
origin_x=60, origin_y=398, width=386, height=196), origin_x=60, origin_y=398, width=386, height=196
),
categories=[ categories=[
_Category( _Category(
index=None, index=None,
score=0.64453125, score=0.64453125,
display_name=None, display_name=None,
category_name='cat') category_name='cat',
]), )
],
),
_Detection( _Detection(
bounding_box=_BoundingBox( bounding_box=_BoundingBox(
origin_x=256, origin_y=395, width=173, height=202), origin_x=256, origin_y=395, width=173, height=202
),
categories=[ categories=[
_Category( _Category(
index=None, index=None,
score=0.51171875, score=0.51171875,
display_name=None, display_name=None,
category_name='cat') category_name='cat',
]), )
],
),
_Detection( _Detection(
bounding_box=_BoundingBox( bounding_box=_BoundingBox(
origin_x=362, origin_y=191, width=325, height=419), origin_x=362, origin_y=191, width=325, height=419
),
categories=[ categories=[
_Category( _Category(
index=None, index=None,
score=0.48828125, score=0.48828125,
display_name=None, display_name=None,
category_name='cat') category_name='cat',
]) )
]) ],
),
]
)
_ALLOW_LIST = ['cat', 'dog'] _ALLOW_LIST = ['cat', 'dog']
_DENY_LIST = ['cat'] _DENY_LIST = ['cat']
_SCORE_THRESHOLD = 0.3 _SCORE_THRESHOLD = 0.3

View File

@ -28,6 +28,10 @@ mediapipe_files(srcs = [
"category_tensor_float_meta.json", "category_tensor_float_meta.json",
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_no_metadata.tflite", "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_no_metadata.tflite",
"coco_ssd_mobilenet_v1_score_calibration.json", "coco_ssd_mobilenet_v1_score_calibration.json",
"deeplabv3.json",
"deeplabv3_with_activation.json",
"deeplabv3_without_labels.json",
"deeplabv3_without_metadata.tflite",
"efficientdet_lite0_v1.json", "efficientdet_lite0_v1.json",
"efficientdet_lite0_v1.tflite", "efficientdet_lite0_v1.tflite",
"labelmap.txt", "labelmap.txt",
@ -44,6 +48,8 @@ mediapipe_files(srcs = [
"mobilenet_v2_1.0_224_without_metadata.tflite", "mobilenet_v2_1.0_224_without_metadata.tflite",
"movie_review.tflite", "movie_review.tflite",
"score_calibration.csv", "score_calibration.csv",
"segmentation_mask_meta.json",
"segmenter_labelmap.txt",
"ssd_mobilenet_v1_no_metadata.json", "ssd_mobilenet_v1_no_metadata.json",
"ssd_mobilenet_v1_no_metadata.tflite", "ssd_mobilenet_v1_no_metadata.tflite",
"tensor_group_meta.json", "tensor_group_meta.json",
@ -87,6 +93,7 @@ filegroup(
"30k-clean.model", "30k-clean.model",
"bert_text_classifier_no_metadata.tflite", "bert_text_classifier_no_metadata.tflite",
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_no_metadata.tflite", "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_no_metadata.tflite",
"deeplabv3_without_metadata.tflite",
"efficientdet_lite0_v1.tflite", "efficientdet_lite0_v1.tflite",
"mobile_ica_8bit-with-custom-metadata.tflite", "mobile_ica_8bit-with-custom-metadata.tflite",
"mobile_ica_8bit-with-large-min-parser-version.tflite", "mobile_ica_8bit-with-large-min-parser-version.tflite",
@ -116,6 +123,9 @@ filegroup(
"classification_tensor_uint8_meta.json", "classification_tensor_uint8_meta.json",
"classification_tensor_unsupported_meta.json", "classification_tensor_unsupported_meta.json",
"coco_ssd_mobilenet_v1_score_calibration.json", "coco_ssd_mobilenet_v1_score_calibration.json",
"deeplabv3.json",
"deeplabv3_with_activation.json",
"deeplabv3_without_labels.json",
"efficientdet_lite0_v1.json", "efficientdet_lite0_v1.json",
"external_file", "external_file",
"feature_tensor_meta.json", "feature_tensor_meta.json",
@ -140,6 +150,8 @@ filegroup(
"score_calibration_file_meta.json", "score_calibration_file_meta.json",
"score_calibration_tensor_meta.json", "score_calibration_tensor_meta.json",
"score_thresholding_meta.json", "score_thresholding_meta.json",
"segmentation_mask_meta.json",
"segmenter_labelmap.txt",
"sentence_piece_tokenizer_meta.json", "sentence_piece_tokenizer_meta.json",
"ssd_mobilenet_v1_no_metadata.json", "ssd_mobilenet_v1_no_metadata.json",
"tensor_group_meta.json", "tensor_group_meta.json",

View File

@ -0,0 +1,66 @@
{
"name": "ImageSegmenter",
"description": "Semantic image segmentation predicts whether each pixel of an image is associated with a certain class.",
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "image",
"description": "Input image to be processed.",
"content": {
"content_properties_type": "ImageProperties",
"content_properties": {
"color_space": "RGB"
}
},
"process_units": [
{
"options_type": "NormalizationOptions",
"options": {
"mean": [
127.5
],
"std": [
127.5
]
}
}
],
"stats": {
"max": [
1.0
],
"min": [
-1.0
]
}
}
],
"output_tensor_metadata": [
{
"name": "segmentation_masks",
"description": "Masks over the target objects with high accuracy.",
"content": {
"content_properties_type": "ImageProperties",
"content_properties": {
"color_space": "GRAYSCALE"
},
"range": {
"min": 1,
"max": 2
}
},
"stats": {},
"associated_files": [
{
"name": "labels.txt",
"description": "Labels for categories that the model can recognize.",
"type": "TENSOR_AXIS_LABELS"
}
]
}
]
}
],
"min_parser_version": "1.0.0"
}

View File

@ -0,0 +1,67 @@
{
"name": "ImageSegmenter",
"description": "Semantic image segmentation predicts whether each pixel of an image is associated with a certain class.",
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "image",
"description": "Input image to be processed.",
"content": {
"content_properties_type": "ImageProperties",
"content_properties": {
"color_space": "RGB"
}
},
"process_units": [
{
"options_type": "NormalizationOptions",
"options": {
"mean": [
127.5
],
"std": [
127.5
]
}
}
],
"stats": {
"max": [
1.0
],
"min": [
-1.0
]
}
}
],
"output_tensor_metadata": [
{
"name": "segmentation_masks",
"description": "Masks over the target objects with high accuracy.",
"content": {
"content_properties_type": "ImageProperties",
"content_properties": {
"color_space": "GRAYSCALE"
},
"range": {
"min": 1,
"max": 2
}
},
"stats": {}
}
],
"custom_metadata": [
{
"name": "SEGMENTER_METADATA",
"data": {
"activation": "SIGMOID"
}
}
]
}
],
"min_parser_version": "1.5.0"
}

View File

@ -0,0 +1,59 @@
{
"name": "ImageSegmenter",
"description": "Semantic image segmentation predicts whether each pixel of an image is associated with a certain class.",
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "image",
"description": "Input image to be processed.",
"content": {
"content_properties_type": "ImageProperties",
"content_properties": {
"color_space": "RGB"
}
},
"process_units": [
{
"options_type": "NormalizationOptions",
"options": {
"mean": [
127.5
],
"std": [
127.5
]
}
}
],
"stats": {
"max": [
1.0
],
"min": [
-1.0
]
}
}
],
"output_tensor_metadata": [
{
"name": "segmentation_masks",
"description": "Masks over the target objects with high accuracy.",
"content": {
"content_properties_type": "ImageProperties",
"content_properties": {
"color_space": "GRAYSCALE"
},
"range": {
"min": 1,
"max": 2
}
},
"stats": {}
}
]
}
],
"min_parser_version": "1.0.0"
}

View File

@ -0,0 +1,24 @@
{
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "segmentation_masks",
"description": "Masks over the target objects.",
"content": {
"content_properties_type": "ImageProperties",
"content_properties": {
"color_space": "GRAYSCALE"
},
"range": {
"min": 1,
"max": 2
}
},
"stats": {
}
}
]
}
]
}

View File

@ -0,0 +1,21 @@
background
aeroplane
bicycle
bird
boat
bottle
bus
car
cat
chair
cow
dining table
dog
horse
motorbike
person
potted plant
sheep
sofa
train
tv

View File

@ -70,6 +70,8 @@ mediapipe_files(srcs = [
"portrait.jpg", "portrait.jpg",
"portrait_hair_expected_mask.jpg", "portrait_hair_expected_mask.jpg",
"portrait_rotated.jpg", "portrait_rotated.jpg",
"pose.jpg",
"pose_detection.tflite",
"right_hands.jpg", "right_hands.jpg",
"right_hands_rotated.jpg", "right_hands_rotated.jpg",
"segmentation_golden_rotation0.png", "segmentation_golden_rotation0.png",
@ -78,6 +80,8 @@ mediapipe_files(srcs = [
"selfie_segm_128_128_3_expected_mask.jpg", "selfie_segm_128_128_3_expected_mask.jpg",
"selfie_segm_144_256_3.tflite", "selfie_segm_144_256_3.tflite",
"selfie_segm_144_256_3_expected_mask.jpg", "selfie_segm_144_256_3_expected_mask.jpg",
"selfie_segmentation.tflite",
"selfie_segmentation_landscape.tflite",
"thumb_up.jpg", "thumb_up.jpg",
"victory.jpg", "victory.jpg",
]) ])
@ -125,6 +129,7 @@ filegroup(
"portrait.jpg", "portrait.jpg",
"portrait_hair_expected_mask.jpg", "portrait_hair_expected_mask.jpg",
"portrait_rotated.jpg", "portrait_rotated.jpg",
"pose.jpg",
"right_hands.jpg", "right_hands.jpg",
"right_hands_rotated.jpg", "right_hands_rotated.jpg",
"segmentation_golden_rotation0.png", "segmentation_golden_rotation0.png",
@ -170,8 +175,11 @@ filegroup(
"mobilenet_v2_1.0_224.tflite", "mobilenet_v2_1.0_224.tflite",
"mobilenet_v3_small_100_224_embedder.tflite", "mobilenet_v3_small_100_224_embedder.tflite",
"palm_detection_full.tflite", "palm_detection_full.tflite",
"pose_detection.tflite",
"selfie_segm_128_128_3.tflite", "selfie_segm_128_128_3.tflite",
"selfie_segm_144_256_3.tflite", "selfie_segm_144_256_3.tflite",
"selfie_segmentation.tflite",
"selfie_segmentation_landscape.tflite",
], ],
) )
@ -197,6 +205,7 @@ filegroup(
"portrait_expected_face_landmarks.pbtxt", "portrait_expected_face_landmarks.pbtxt",
"portrait_expected_face_landmarks_with_attention.pbtxt", "portrait_expected_face_landmarks_with_attention.pbtxt",
"portrait_rotated_expected_detection.pbtxt", "portrait_rotated_expected_detection.pbtxt",
"pose_expected_detection.pbtxt",
"thumb_up_landmarks.pbtxt", "thumb_up_landmarks.pbtxt",
"thumb_up_rotated_landmarks.pbtxt", "thumb_up_rotated_landmarks.pbtxt",
"victory_landmarks.pbtxt", "victory_landmarks.pbtxt",

View File

@ -0,0 +1,27 @@
# proto-file: mediapipe/framework/formats/detection.proto
# proto-message: Detection
location_data {
format: BOUNDING_BOX
bounding_box {
xmin: 397
ymin: 198
width: 199
height: 199
}
relative_keypoints {
x: 0.4879558
y: 0.7013345
}
relative_keypoints {
x: 0.48453212
y: 0.32265592
}
relative_keypoints {
x: 0.4992165
y: 0.4854874
}
relative_keypoints {
x: 0.50227845
y: 0.159788
}
}

View File

@ -24,6 +24,7 @@ VISION_LIBS = [
"//mediapipe/tasks/web/vision/image_classifier", "//mediapipe/tasks/web/vision/image_classifier",
"//mediapipe/tasks/web/vision/image_embedder", "//mediapipe/tasks/web/vision/image_embedder",
"//mediapipe/tasks/web/vision/image_segmenter", "//mediapipe/tasks/web/vision/image_segmenter",
"//mediapipe/tasks/web/vision/interactive_segmenter",
"//mediapipe/tasks/web/vision/object_detector", "//mediapipe/tasks/web/vision/object_detector",
] ]

View File

@ -75,6 +75,24 @@ imageSegmenter.segment(image, (masks, width, height) => {
}); });
``` ```
## Interactive Segmentation
The MediaPipe Interactive Segmenter lets you select a region of interest and
segment an image based on it.
```
const vision = await FilesetResolver.forVisionTasks(
"https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm"
);
const interactiveSegmenter = await InteractiveSegmenter.createFromModelPath(
vision, "model.tflite"
);
const image = document.getElementById("image") as HTMLImageElement;
interactiveSegmenter.segment(image, { keypoint: { x: 0.1, y: 0.2 } },
(masks, width, height) => { ... }
);
```
## Object Detection ## Object Detection
The MediaPipe Object Detector task lets you detect the presence and location of The MediaPipe Object Detector task lets you detect the presence and location of

View File

@ -20,6 +20,7 @@ import {HandLandmarker as HandLandmarkerImpl} from '../../../tasks/web/vision/ha
import {ImageClassifier as ImageClassifierImpl} from '../../../tasks/web/vision/image_classifier/image_classifier'; import {ImageClassifier as ImageClassifierImpl} from '../../../tasks/web/vision/image_classifier/image_classifier';
import {ImageEmbedder as ImageEmbedderImpl} from '../../../tasks/web/vision/image_embedder/image_embedder'; import {ImageEmbedder as ImageEmbedderImpl} from '../../../tasks/web/vision/image_embedder/image_embedder';
import {ImageSegmenter as ImageSegementerImpl} from '../../../tasks/web/vision/image_segmenter/image_segmenter'; import {ImageSegmenter as ImageSegementerImpl} from '../../../tasks/web/vision/image_segmenter/image_segmenter';
import {InteractiveSegmenter as InteractiveSegmenterImpl} from '../../../tasks/web/vision/interactive_segmenter/interactive_segmenter';
import {ObjectDetector as ObjectDetectorImpl} from '../../../tasks/web/vision/object_detector/object_detector'; import {ObjectDetector as ObjectDetectorImpl} from '../../../tasks/web/vision/object_detector/object_detector';
// Declare the variables locally so that Rollup in OSS includes them explicitly // Declare the variables locally so that Rollup in OSS includes them explicitly
@ -30,6 +31,7 @@ const HandLandmarker = HandLandmarkerImpl;
const ImageClassifier = ImageClassifierImpl; const ImageClassifier = ImageClassifierImpl;
const ImageEmbedder = ImageEmbedderImpl; const ImageEmbedder = ImageEmbedderImpl;
const ImageSegmenter = ImageSegementerImpl; const ImageSegmenter = ImageSegementerImpl;
const InteractiveSegmenter = InteractiveSegmenterImpl;
const ObjectDetector = ObjectDetectorImpl; const ObjectDetector = ObjectDetectorImpl;
export { export {
@ -39,5 +41,6 @@ export {
ImageClassifier, ImageClassifier,
ImageEmbedder, ImageEmbedder,
ImageSegmenter, ImageSegmenter,
InteractiveSegmenter,
ObjectDetector ObjectDetector
}; };

View File

@ -0,0 +1,62 @@
# This contains the MediaPipe Interactive Segmenter Task.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
mediapipe_ts_library(
name = "interactive_segmenter",
srcs = ["interactive_segmenter.ts"],
deps = [
":interactive_segmenter_types",
"//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/framework:calculator_options_jspb_proto",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_jspb_proto",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_jspb_proto",
"//mediapipe/tasks/web/components/containers:keypoint",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/vision/core:image_processing_options",
"//mediapipe/tasks/web/vision/core:types",
"//mediapipe/tasks/web/vision/core:vision_task_runner",
"//mediapipe/util:color_jspb_proto",
"//mediapipe/util:render_data_jspb_proto",
"//mediapipe/web/graph_runner:graph_runner_ts",
],
)
mediapipe_ts_declaration(
name = "interactive_segmenter_types",
srcs = ["interactive_segmenter_options.d.ts"],
deps = [
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:classifier_options",
"//mediapipe/tasks/web/vision/core:vision_task_options",
],
)
mediapipe_ts_library(
name = "interactive_segmenter_test_lib",
testonly = True,
srcs = [
"interactive_segmenter_test.ts",
],
deps = [
":interactive_segmenter",
":interactive_segmenter_types",
"//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:task_runner_test_utils",
"//mediapipe/util:render_data_jspb_proto",
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
],
)
jasmine_node_test(
name = "interactive_segmenter_test",
tags = ["nomsan"],
deps = [":interactive_segmenter_test_lib"],
)

View File

@ -0,0 +1,306 @@
/**
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options_pb';
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
import {RegionOfInterest, SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types';
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {Color as ColorProto} from '../../../../util/color_pb';
import {RenderAnnotation as RenderAnnotationProto, RenderData as RenderDataProto} from '../../../../util/render_data_pb';
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource url
import {InteractiveSegmenterOptions} from './interactive_segmenter_options';
export * from './interactive_segmenter_options';
export {SegmentationMask, SegmentationMaskCallback, RegionOfInterest};
export {ImageSource};
const IMAGE_IN_STREAM = 'image_in';
const NORM_RECT_IN_STREAM = 'norm_rect_in';
const ROI_IN_STREAM = 'roi_in';
const IMAGE_OUT_STREAM = 'image_out';
const IMAGEA_SEGMENTER_GRAPH =
'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph';
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
/**
* Performs interactive segmentation on images.
*
* Users can represent user interaction through `RegionOfInterest`, which gives
* a hint to InteractiveSegmenter to perform segmentation focusing on the given
* region of interest.
*
* The API expects a TFLite model with mandatory TFLite Model Metadata.
*
* Input tensor:
* (kTfLiteUInt8/kTfLiteFloat32)
* - image input of size `[batch x height x width x channels]`.
* - batch inference is not supported (`batch` is required to be 1).
* - RGB inputs are supported (`channels` is required to be 3).
* - if type is kTfLiteFloat32, NormalizationOptions are required to be
* attached to the metadata for input normalization.
* Output tensors:
* (kTfLiteUInt8/kTfLiteFloat32)
* - list of segmented masks.
* - if `output_type` is CATEGORY_MASK, a single uint8 Image (vector of size 1).
* - if `output_type` is CONFIDENCE_MASK, float32 Image list of size
* `channels`.
* - batch is always 1
*/
export class InteractiveSegmenter extends VisionTaskRunner {
private userCallback: SegmentationMaskCallback = () => {};
private readonly options: ImageSegmenterGraphOptionsProto;
private readonly segmenterOptions: SegmenterOptionsProto;
/**
* Initializes the Wasm runtime and creates a new interactive segmenter from
* the provided options.
* @param wasmFileset A configuration object that provides the location of
* the Wasm binary and its loader.
* @param interactiveSegmenterOptions The options for the Interactive
* Segmenter. Note that either a path to the model asset or a model buffer
* needs to be provided (via `baseOptions`).
* @return A new `InteractiveSegmenter`.
*/
static createFromOptions(
wasmFileset: WasmFileset,
interactiveSegmenterOptions: InteractiveSegmenterOptions):
Promise<InteractiveSegmenter> {
return VisionTaskRunner.createInstance(
InteractiveSegmenter, /* initializeCanvas= */ true, wasmFileset,
interactiveSegmenterOptions);
}
/**
* Initializes the Wasm runtime and creates a new interactive segmenter based
* on the provided model asset buffer.
* @param wasmFileset A configuration object that provides the location of
* the Wasm binary and its loader.
* @param modelAssetBuffer A binary representation of the model.
* @return A new `InteractiveSegmenter`.
*/
static createFromModelBuffer(
wasmFileset: WasmFileset,
modelAssetBuffer: Uint8Array): Promise<InteractiveSegmenter> {
return VisionTaskRunner.createInstance(
InteractiveSegmenter, /* initializeCanvas= */ true, wasmFileset,
{baseOptions: {modelAssetBuffer}});
}
/**
* Initializes the Wasm runtime and creates a new interactive segmenter based
* on the path to the model asset.
* @param wasmFileset A configuration object that provides the location of
* the Wasm binary and its loader.
* @param modelAssetPath The path to the model asset.
* @return A new `InteractiveSegmenter`.
*/
static createFromModelPath(
wasmFileset: WasmFileset,
modelAssetPath: string): Promise<InteractiveSegmenter> {
return VisionTaskRunner.createInstance(
InteractiveSegmenter, /* initializeCanvas= */ true, wasmFileset,
{baseOptions: {modelAssetPath}});
}
/** @hideconstructor */
constructor(
wasmModule: WasmModule,
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
super(
new VisionGraphRunner(wasmModule, glCanvas), IMAGE_IN_STREAM,
NORM_RECT_IN_STREAM, /* roiAllowed= */ false);
this.options = new ImageSegmenterGraphOptionsProto();
this.segmenterOptions = new SegmenterOptionsProto();
this.options.setSegmenterOptions(this.segmenterOptions);
this.options.setBaseOptions(new BaseOptionsProto());
}
protected override get baseOptions(): BaseOptionsProto {
return this.options.getBaseOptions()!;
}
protected override set baseOptions(proto: BaseOptionsProto) {
this.options.setBaseOptions(proto);
}
/**
* Sets new options for the interactive segmenter.
*
* Calling `setOptions()` with a subset of options only affects those
* options. You can reset an option back to its default value by
* explicitly setting it to `undefined`.
*
* @param options The options for the interactive segmenter.
* @return A Promise that resolves when the settings have been applied.
*/
override setOptions(options: InteractiveSegmenterOptions): Promise<void> {
if (options.outputType === 'CONFIDENCE_MASK') {
this.segmenterOptions.setOutputType(
SegmenterOptionsProto.OutputType.CONFIDENCE_MASK);
} else {
this.segmenterOptions.setOutputType(
SegmenterOptionsProto.OutputType.CATEGORY_MASK);
}
return super.applyOptions(options);
}
/**
* Performs interactive segmentation on the provided single image and invokes
* the callback with the response. The `roi` parameter is used to represent a
* user's region of interest for segmentation.
*
* If the output_type is `CATEGORY_MASK`, the callback is invoked with a
* vector containing a single category mask, where each pixel value encodes
* the predicted category. If the output_type is `CONFIDENCE_MASK`, the
* callback is invoked with a vector of confidence masks, one per category.
* The method returns synchronously once the callback returns.
*
* @param image An image to process.
* @param roi The region of interest for segmentation.
* @param callback The callback that is invoked with the segmented masks. The
* lifetime of the returned data is only guaranteed for the duration of the
* callback.
*/
segment(
image: ImageSource, roi: RegionOfInterest,
callback: SegmentationMaskCallback): void;
/**
* Performs interactive segmentation on the provided single image and invokes
* the callback with the response. The `roi` parameter is used to represent a
* user's region of interest for segmentation.
*
* The 'image_processing_options' parameter can be used to specify the
* rotation to apply to the image before performing segmentation, by setting
* its 'rotationDegrees' field. Note that specifying a region-of-interest
* using the 'regionOfInterest' field is NOT supported and will result in an
* error.
*
* If the output_type is `CATEGORY_MASK`, the callback is invoked with a
* vector containing a single category mask, where each pixel value encodes
* the predicted category. If the output_type is `CONFIDENCE_MASK`, the
* callback is invoked with a vector of confidence masks, one per category.
* The method returns synchronously once the callback returns.
*
* @param image An image to process.
* @param roi The region of interest for segmentation.
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how
* to process the input image before running inference.
* @param callback The callback that is invoked with the segmented masks. The
* lifetime of the returned data is only guaranteed for the duration of the
* callback.
*/
segment(
image: ImageSource, roi: RegionOfInterest,
imageProcessingOptions: ImageProcessingOptions,
callback: SegmentationMaskCallback): void;
segment(
image: ImageSource, roi: RegionOfInterest,
imageProcessingOptionsOrCallback: ImageProcessingOptions|
SegmentationMaskCallback,
callback?: SegmentationMaskCallback): void {
const imageProcessingOptions =
typeof imageProcessingOptionsOrCallback !== 'function' ?
imageProcessingOptionsOrCallback :
{};
this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ?
imageProcessingOptionsOrCallback :
callback!;
this.processRenderData(roi, this.getSynctheticTimestamp());
this.processImageData(image, imageProcessingOptions);
this.userCallback = () => {};
}
/** Updates the MediaPipe graph configuration. */
protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(IMAGE_IN_STREAM);
graphConfig.addInputStream(ROI_IN_STREAM);
graphConfig.addInputStream(NORM_RECT_IN_STREAM);
graphConfig.addOutputStream(IMAGE_OUT_STREAM);
const calculatorOptions = new CalculatorOptions();
calculatorOptions.setExtension(
ImageSegmenterGraphOptionsProto.ext, this.options);
const segmenterNode = new CalculatorGraphConfig.Node();
segmenterNode.setCalculator(INTERACTIVE_SEGMENTER_GRAPH);
segmenterNode.addInputStream('IMAGE:' + IMAGE_IN_STREAM);
segmenterNode.addInputStream('ROI:' + ROI_IN_STREAM);
segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_IN_STREAM);
segmenterNode.addOutputStream('GROUPED_SEGMENTATION:' + IMAGE_OUT_STREAM);
segmenterNode.setOptions(calculatorOptions);
graphConfig.addNode(segmenterNode);
this.graphRunner.attachImageVectorListener(
IMAGE_OUT_STREAM, (masks, timestamp) => {
if (masks.length === 0) {
this.userCallback([], 0, 0);
} else {
this.userCallback(
masks.map(m => m.data), masks[0].width, masks[0].height);
}
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachEmptyPacketListener(IMAGE_OUT_STREAM, timestamp => {
this.setLatestOutputTimestamp(timestamp);
});
const binaryGraph = graphConfig.serializeBinary();
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
}
/**
* Converts the user-facing RegionOfInterest message to the RenderData proto
* and sends it to the graph
*/
private processRenderData(roi: RegionOfInterest, timestamp: number): void {
const renderData = new RenderDataProto();
const renderAnnotation = new RenderAnnotationProto();
const color = new ColorProto();
color.setR(255);
renderAnnotation.setColor(color);
const point = new RenderAnnotationProto.Point();
point.setNormalized(true);
point.setX(roi.keypoint.x);
point.setY(roi.keypoint.y);
renderAnnotation.setPoint(point);
renderData.addRenderAnnotations(renderAnnotation);
this.graphRunner.addProtoToStream(
renderData.serializeBinary(), 'mediapipe.RenderData', ROI_IN_STREAM,
timestamp);
}
}
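
For reference, a minimal usage sketch of the API added above. The wasm and model asset paths are placeholders, and the import assumes the published `@mediapipe/tasks-vision` bundle re-exports `InteractiveSegmenter` and `FilesetResolver` (the vision index.ts change later in this commit adds the segmenter export):

import {FilesetResolver, InteractiveSegmenter} from '@mediapipe/tasks-vision';

async function segmentAroundPoint(image: HTMLImageElement): Promise<void> {
  // Locate the Wasm assets; the path is illustrative only.
  const vision = await FilesetResolver.forVisionTasks('/wasm');
  // Create the segmenter from a model path; a model buffer would work as well.
  const segmenter = await InteractiveSegmenter.createFromOptions(vision, {
    baseOptions: {modelAssetPath: '/models/interactive_segmenter.tflite'},
    outputType: 'CATEGORY_MASK',
  });
  // The keypoint is a normalized (x, y) hint marking the object of interest.
  segmenter.segment(image, {keypoint: {x: 0.5, y: 0.5}}, (masks, width, height) => {
    // The mask data is only guaranteed to be valid inside this callback.
    console.log(`Received ${masks.length} mask(s) of size ${width}x${height}`);
  });
}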

View File

@ -0,0 +1,36 @@
/**
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options';
/** Options to configure the MediaPipe Interactive Segmenter Task */
export interface InteractiveSegmenterOptions extends TaskRunnerOptions {
/**
* The output type of segmentation results.
*
* The two supported modes are:
* - Category Mask: Gives a single output mask where each pixel represents
* the class which the pixel in the original image was
* predicted to belong to.
* - Confidence Mask: Gives a list of output masks (one for each class). For
* each mask, the pixel represents the prediction
* confidence, usually in the [0.0, 1.0] range.
*
* Defaults to `CATEGORY_MASK`.
*/
outputType?: 'CATEGORY_MASK'|'CONFIDENCE_MASK'|undefined;
}
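
A short sketch of how the `outputType` option interacts with `setOptions()` and `segment()`, assuming an already-created `InteractiveSegmenter` named `segmenter` and an `image` element as in the earlier sketch:

// Switch the segmenter to confidence masks.
await segmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
segmenter.segment(image, {keypoint: {x: 0.3, y: 0.7}}, (masks) => {
  // With CONFIDENCE_MASK, each entry is a per-category Float32Array of
  // per-pixel confidences in the [0.0, 1.0] range.
  const first = masks[0] as Float32Array;
  const peak = first.reduce((max, v) => Math.max(max, v), 0);
  console.log('Peak confidence for the first category:', peak);
});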

View File

@ -0,0 +1,214 @@
/**
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import 'jasmine';
// Placeholder for internal dependency on encodeByteArray
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
import {RenderData as RenderDataProto} from '../../../../util/render_data_pb';
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
import {InteractiveSegmenter, RegionOfInterest} from './interactive_segmenter';
const ROI: RegionOfInterest = {
keypoint: {x: 0.1, y: 0.2}
};
class InteractiveSegmenterFake extends InteractiveSegmenter implements
MediapipeTasksFake {
calculatorName =
'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph';
attachListenerSpies: jasmine.Spy[] = [];
graph: CalculatorGraphConfig|undefined;
fakeWasmModule: SpyWasmModule;
imageVectorListener:
((images: WasmImage[], timestamp: number) => void)|undefined;
lastRoi?: RenderDataProto;
constructor() {
super(createSpyWasmModule(), /* glCanvas= */ null);
this.fakeWasmModule =
this.graphRunner.wasmModule as unknown as SpyWasmModule;
this.attachListenerSpies[0] =
spyOn(this.graphRunner, 'attachImageVectorListener')
.and.callFake((stream, listener) => {
expect(stream).toEqual('image_out');
this.imageVectorListener = listener;
});
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
});
spyOn(this.graphRunner, 'addGpuBufferAsImageToStream');
spyOn(this.graphRunner, 'addProtoToStream')
.and.callFake((data, protoName, stream) => {
if (stream === 'roi_in') {
expect(protoName).toEqual('mediapipe.RenderData');
this.lastRoi = RenderDataProto.deserializeBinary(data);
}
});
}
}
describe('InteractiveSegmenter', () => {
let interactiveSegmenter: InteractiveSegmenterFake;
beforeEach(async () => {
addJasmineCustomFloatEqualityTester();
interactiveSegmenter = new InteractiveSegmenterFake();
await interactiveSegmenter.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
});
it('initializes graph', async () => {
verifyGraph(interactiveSegmenter);
verifyListenersRegistered(interactiveSegmenter);
});
it('reloads graph when settings are changed', async () => {
await interactiveSegmenter.setOptions({outputType: 'CATEGORY_MASK'});
verifyGraph(interactiveSegmenter, [['segmenterOptions', 'outputType'], 1]);
verifyListenersRegistered(interactiveSegmenter);
await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
verifyGraph(interactiveSegmenter, [['segmenterOptions', 'outputType'], 2]);
verifyListenersRegistered(interactiveSegmenter);
});
it('can use custom models', async () => {
const newModel = new Uint8Array([0, 1, 2, 3, 4]);
const newModelBase64 = Buffer.from(newModel).toString('base64');
await interactiveSegmenter.setOptions({
baseOptions: {
modelAssetBuffer: newModel,
}
});
verifyGraph(
interactiveSegmenter,
/* expectedCalculatorOptions= */ undefined,
/* expectedBaseOptions= */
[
'modelAsset', {
fileContent: newModelBase64,
fileName: undefined,
fileDescriptorMeta: undefined,
filePointerMeta: undefined
}
]);
});
describe('setOptions()', () => {
const fieldPath = ['segmenterOptions', 'outputType'];
it(`can set outputType`, async () => {
await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
verifyGraph(interactiveSegmenter, [fieldPath, 2]);
});
it(`can clear outputType`, async () => {
await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
verifyGraph(interactiveSegmenter, [fieldPath, 2]);
await interactiveSegmenter.setOptions({outputType: undefined});
verifyGraph(interactiveSegmenter, [fieldPath, 1]);
});
});
it('doesn\'t support region of interest', () => {
expect(() => {
interactiveSegmenter.segment(
{} as HTMLImageElement, ROI,
{regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}, () => {});
}).toThrowError('This task doesn\'t support region-of-interest.');
});
it('sends region-of-interest', (done) => {
interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
expect(interactiveSegmenter.lastRoi).toBeDefined();
expect(interactiveSegmenter.lastRoi!.toObject().renderAnnotationsList![0])
.toEqual(jasmine.objectContaining({
color: {r: 255, b: undefined, g: undefined},
}));
done();
});
interactiveSegmenter.segment({} as HTMLImageElement, ROI, () => {});
});
it('supports category masks', (done) => {
const mask = new Uint8Array([1, 2, 3, 4]);
// Pass the test data to our listener
interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
verifyListenersRegistered(interactiveSegmenter);
interactiveSegmenter.imageVectorListener!(
[
{data: mask, width: 2, height: 2},
],
/* timestamp= */ 1337);
});
// Invoke the image segmenter
interactiveSegmenter.segment(
{} as HTMLImageElement, ROI, (masks, width, height) => {
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
.toHaveBeenCalled();
expect(masks).toHaveSize(1);
expect(masks[0]).toEqual(mask);
expect(width).toEqual(2);
expect(height).toEqual(2);
done();
});
});
it('supports confidence masks', async () => {
const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]);
const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]);
await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
// Pass the test data to our listener
interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
verifyListenersRegistered(interactiveSegmenter);
interactiveSegmenter.imageVectorListener!(
[
{data: mask1, width: 2, height: 2},
{data: mask2, width: 2, height: 2},
],
1337);
});
return new Promise<void>(resolve => {
// Invoke the image segmenter
interactiveSegmenter.segment(
{} as HTMLImageElement, ROI, (masks, width, height) => {
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
.toHaveBeenCalled();
expect(masks).toHaveSize(2);
expect(masks[0]).toEqual(mask1);
expect(masks[1]).toEqual(mask2);
expect(width).toEqual(2);
expect(height).toEqual(2);
resolve();
});
});
});
});

View File

@ -20,4 +20,5 @@ export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker';
export * from '../../../tasks/web/vision/image_classifier/image_classifier';
export * from '../../../tasks/web/vision/image_embedder/image_embedder';
export * from '../../../tasks/web/vision/image_segmenter/image_segmenter';
+export * from '../../../tasks/web/vision/interactive_segmenter/interactive_segmenter';
export * from '../../../tasks/web/vision/object_detector/object_detector';

third_party/BUILD
View File

@ -169,7 +169,11 @@ cmake_external(
"-lm", "-lm",
"-lpthread", "-lpthread",
"-lrt", "-lrt",
], ] + select({
"//mediapipe:ios": ["-framework Cocoa"],
"//mediapipe:macos": ["-framework Cocoa"],
"//conditions:default": [],
}),
shared_libraries = select({ shared_libraries = select({
"@bazel_tools//src/conditions:darwin": ["libopencv_%s.%s.dylib" % (module, OPENCV_SO_VERSION) for module in OPENCV_MODULES], "@bazel_tools//src/conditions:darwin": ["libopencv_%s.%s.dylib" % (module, OPENCV_SO_VERSION) for module in OPENCV_MODULES],
# Only the shared objects listed here will be linked in the directory # Only the shared objects listed here will be linked in the directory

View File

@ -72,8 +72,8 @@ def external_files():
    http_file(
        name = "com_google_mediapipe_BUILD_orig",
-        sha256 = "64d5343a6a5f9be06db0a5074a2260f9ae63a989fe01702832cd215680dc19c1",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD.orig?generation=1678323576393653"],
+        sha256 = "d86b98b82e00dd87cd46bd1429bf5eaa007b500c1a24d9316b73309f2e6c8df8",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD.orig?generation=1678737479599640"],
    )
    http_file(
@ -208,10 +208,34 @@ def external_files():
        urls = ["https://storage.googleapis.com/mediapipe-assets/corrupted_mobilenet_v1_0.25_224_1_default_1.tflite?generation=1661875706780536"],
    )
+    http_file(
+        name = "com_google_mediapipe_deeplabv3_json",
+        sha256 = "f299835bd9ea1cceb25fdf40a761a22716cbd20025cd67c365a860527f178b7f",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/deeplabv3.json?generation=1678818040715103"],
+    )
    http_file(
        name = "com_google_mediapipe_deeplabv3_tflite",
-        sha256 = "9711334db2b01d5894feb8ed0f5cb3e97d125b8d229f8d8692f625801818f5ef",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/deeplabv3.tflite?generation=1661875711618421"],
+        sha256 = "5faed2c653905d3e22a8f6f29ee198da84e9b0e7936a207bf431f17f6b4d87ff",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/deeplabv3.tflite?generation=1678775085237701"],
    )
+    http_file(
+        name = "com_google_mediapipe_deeplabv3_with_activation_json",
+        sha256 = "a7633476d02f970db3cc30f5f027bcb608149e02207b2ccae36a4b69d730c82c",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/deeplabv3_with_activation.json?generation=1678818047050984"],
+    )
+    http_file(
+        name = "com_google_mediapipe_deeplabv3_without_labels_json",
+        sha256 = "7d045a583a4046f17a52d2078b0175607a45ed0cc187558325f9c66534c08401",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/deeplabv3_without_labels.json?generation=1678818050191996"],
+    )
+    http_file(
+        name = "com_google_mediapipe_deeplabv3_without_metadata_tflite",
+        sha256 = "68a539782c2c6a72f8aac3724600124a85ed977162b44e84cbae5db717c933c6",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/deeplabv3_without_metadata.tflite?generation=1678818053623010"],
+    )
    http_file(
@ -390,8 +414,8 @@ def external_files():
    http_file(
        name = "com_google_mediapipe_hair_segmentation_tflite",
-        sha256 = "0bec40bc9ba97c4143f3d4225a935014abffea37c1f3766ae32aba3f2748e711",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/hair_segmentation.tflite?generation=1678218355806671"],
+        sha256 = "7cbddcfe6f6e10c3e0a509eb2e14225fda5c0de6c35e2e8c6ca8e3971988fc17",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/hair_segmentation.tflite?generation=1678775089064550"],
    )
    http_file(
@ -823,7 +847,7 @@ def external_files():
    http_file(
        name = "com_google_mediapipe_portrait_expected_face_geometry_with_attention_pbtxt",
        sha256 = "7ed1eed98e61e0a10811bb611c895d87c8023f398a36db01b6d9ba2e1ab09e16",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_face_geometry_with_attention.pbtxt?generation=1678505004840652"],
+        urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_face_geometry_with_attention.pbtxt?generation=1678737486927530"],
    )
    http_file(
@ -864,8 +888,20 @@ def external_files():
    http_file(
        name = "com_google_mediapipe_pose_detection_tflite",
-        sha256 = "a63c614bef30d35947f13be361820b1e4e3bec9cfeebf4d11216a18373108e85",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/pose_detection.tflite?generation=1661875889147923"],
+        sha256 = "9ba9dd3d42efaaba86b4ff0122b06f29c4122e756b329d89dca1e297fd8f866c",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/pose_detection.tflite?generation=1678737489600422"],
    )
+    http_file(
+        name = "com_google_mediapipe_pose_expected_detection_pbtxt",
+        sha256 = "e0d40e98dd5320a780a642c336d0c8720243ac5bcc0e39c4061ad970a503ae24",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/pose_expected_detection.pbtxt?generation=1678737492211540"],
+    )
+    http_file(
+        name = "com_google_mediapipe_pose_jpg",
+        sha256 = "c8a830ed683c0276d713dd5aeda28f415f10cd6291972084a40d0d8b934ed62b",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/pose.jpg?generation=1678737494661975"],
+    )
    http_file(
@ -964,6 +1000,18 @@ def external_files():
        urls = ["https://storage.googleapis.com/mediapipe-assets/segmentation_input_rotation0.jpg?generation=1661875914048401"],
    )
+    http_file(
+        name = "com_google_mediapipe_segmentation_mask_meta_json",
+        sha256 = "4294d53b309c1fbe38a5184de4057576c3dec14e07d16491f1dd459ac9116ab3",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/segmentation_mask_meta.json?generation=1678818065134737"],
+    )
+    http_file(
+        name = "com_google_mediapipe_segmenter_labelmap_txt",
+        sha256 = "d9efa78274f1799ddbcab1f87263e19dae338c1697de47a5b270c9526c45d364",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/segmenter_labelmap.txt?generation=1678818068181025"],
+    )
    http_file(
        name = "com_google_mediapipe_selfie_segm_128_128_3_expected_mask_jpg",
        sha256 = "a295f3ab394a5e0caff2db5041337da58341ec331f1413ef91f56e0d650b4a1e",
@ -972,8 +1020,8 @@ def external_files():
    http_file(
        name = "com_google_mediapipe_selfie_segm_128_128_3_tflite",
-        sha256 = "bb154f248543c0738e32f1c74375245651351a84746dc21f10bdfaabd8fae4ca",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segm_128_128_3.tflite?generation=1661875919964123"],
+        sha256 = "8322982866488b063af6531b1d16ac27c7bf404135b7905f20aaf5e6af7aa45b",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segm_128_128_3.tflite?generation=1678775097370282"],
    )
    http_file(
@ -984,20 +1032,20 @@ def external_files():
    http_file(
        name = "com_google_mediapipe_selfie_segm_144_256_3_tflite",
-        sha256 = "5c770b8834ad50586599eae7710921be09d356898413fc0bf37a9458da0610eb",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segm_144_256_3.tflite?generation=1661875925519713"],
+        sha256 = "f16a9551a408edeadd53f70d1d2911fc20f9f9de7a394129a268ca9faa2d6a08",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segm_144_256_3.tflite?generation=1678775099616375"],
    )
    http_file(
        name = "com_google_mediapipe_selfie_segmentation_landscape_tflite",
-        sha256 = "4aafe6223bb8dac6fac8ca8ed56852870a33051ef3f6238822d282a109962894",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation_landscape.tflite?generation=1661875928328455"],
+        sha256 = "28fb4c287d6295a2dba6c1f43b43315a37f927ddcd6693d635d625d176eef162",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation_landscape.tflite?generation=1678775102234495"],
    )
    http_file(
        name = "com_google_mediapipe_selfie_segmentation_tflite",
-        sha256 = "8d13b7fae74af625c641226813616a2117bd6bca19eb3b75574621fc08557f27",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation.tflite?generation=1661875931201364"],
+        sha256 = "b0e2ec6f95107795b952b27f3d92806b45f0bc069dac76dcd264cd1b90d61c6c",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation.tflite?generation=1678775104900954"],
    )
    http_file(
@ -1224,8 +1272,8 @@ def external_files():
    http_file(
        name = "com_google_mediapipe_object_detection_saved_model_README_md",
-        sha256 = "fe163cf12fbd017738a2fd360c03d223e964ba6404ac75c635f5918784e9c34d",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/README.md?generation=1661875995856372"],
+        sha256 = "acc23dee09f69210717ac060035c844ba902e8271486f1086f29fb156c236690",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/README.md?generation=1678737498915254"],
    )
    http_file(