Merge branch 'ios-normalized-keypoint-hash' into ios-async-calls-fixes
commit db5fd168b6
@@ -95,7 +95,8 @@ absl::Status FrameBufferProcessor::Convert(const mediapipe::Image& input,
                         static_cast<int>(range_max) == 255);
   }

-  auto input_frame = input.GetGpuBuffer().GetReadView<FrameBuffer>();
+  auto input_frame =
+      input.GetGpuBuffer(/*upload_to_gpu=*/false).GetReadView<FrameBuffer>();
   const auto& output_shape = output_tensor.shape();
   MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
   FrameBuffer::Dimension output_dimension{/*width=*/output_shape.dims[2],
@@ -1285,12 +1285,14 @@ cc_library(
     srcs = ["flat_color_image_calculator.cc"],
     deps = [
         ":flat_color_image_calculator_cc_proto",
+        "//mediapipe/framework:calculator_contract",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework/api2:node",
         "//mediapipe/framework/formats:image",
         "//mediapipe/framework/formats:image_frame",
         "//mediapipe/framework/formats:image_frame_opencv",
         "//mediapipe/framework/port:opencv_core",
+        "//mediapipe/framework/port:ret_check",
         "//mediapipe/util:color_cc_proto",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
@@ -15,14 +15,13 @@
 #include <memory>

 #include "absl/status/status.h"
-#include "absl/strings/str_cat.h"
 #include "mediapipe/calculators/util/flat_color_image_calculator.pb.h"
 #include "mediapipe/framework/api2/node.h"
+#include "mediapipe/framework/calculator_contract.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/formats/image.h"
 #include "mediapipe/framework/formats/image_frame.h"
 #include "mediapipe/framework/formats/image_frame_opencv.h"
-#include "mediapipe/framework/port/opencv_core_inc.h"
 #include "mediapipe/util/color.pb.h"

 namespace mediapipe {
@@ -32,6 +31,7 @@ namespace {
 using ::mediapipe::api2::Input;
 using ::mediapipe::api2::Node;
 using ::mediapipe::api2::Output;
+using ::mediapipe::api2::SideOutput;
 }  // namespace

 // A calculator for generating an image filled with a single color.
@@ -45,7 +45,8 @@ using ::mediapipe::api2::Output;
 //
 // Outputs:
 //   IMAGE (Image)
-//     Image filled with the requested color.
+//     Image filled with the requested color. Can be either an output_stream
+//     or an output_side_packet.
 //
 // Example useage:
 // node {
@@ -68,9 +69,10 @@ class FlatColorImageCalculator : public Node {
  public:
   static constexpr Input<Image>::Optional kInImage{"IMAGE"};
   static constexpr Input<Color>::Optional kInColor{"COLOR"};
-  static constexpr Output<Image> kOutImage{"IMAGE"};
+  static constexpr Output<Image>::Optional kOutImage{"IMAGE"};
+  static constexpr SideOutput<Image>::Optional kOutSideImage{"IMAGE"};

-  MEDIAPIPE_NODE_CONTRACT(kInImage, kInColor, kOutImage);
+  MEDIAPIPE_NODE_CONTRACT(kInImage, kInColor, kOutImage, kOutSideImage);

   static absl::Status UpdateContract(CalculatorContract* cc) {
     const auto& options = cc->Options<FlatColorImageCalculatorOptions>();
@@ -81,6 +83,13 @@ class FlatColorImageCalculator : public Node {
     RET_CHECK(kInColor(cc).IsConnected() ^ options.has_color())
         << "Either set COLOR input stream, or set through options";

+    RET_CHECK(kOutImage(cc).IsConnected() ^ kOutSideImage(cc).IsConnected())
+        << "Set IMAGE either as output stream, or as output side packet";
+
+    RET_CHECK(!kOutSideImage(cc).IsConnected() ||
+              (options.has_output_height() && options.has_output_width()))
+        << "Set size through options, when setting IMAGE as output side packet";
+
     return absl::OkStatus();
   }

@@ -88,6 +97,9 @@ class FlatColorImageCalculator : public Node {
   absl::Status Process(CalculatorContext* cc) override;

  private:
+  std::optional<std::shared_ptr<ImageFrame>> CreateOutputFrame(
+      CalculatorContext* cc);
+
   bool use_dimension_from_option_ = false;
   bool use_color_from_option_ = false;
 };
@@ -96,10 +108,31 @@ MEDIAPIPE_REGISTER_NODE(FlatColorImageCalculator);
 absl::Status FlatColorImageCalculator::Open(CalculatorContext* cc) {
   use_dimension_from_option_ = !kInImage(cc).IsConnected();
   use_color_from_option_ = !kInColor(cc).IsConnected();
+
+  if (!kOutImage(cc).IsConnected()) {
+    std::optional<std::shared_ptr<ImageFrame>> output_frame =
+        CreateOutputFrame(cc);
+    if (output_frame.has_value()) {
+      kOutSideImage(cc).Set(Image(output_frame.value()));
+    }
+  }
   return absl::OkStatus();
 }

 absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) {
+  if (kOutImage(cc).IsConnected()) {
+    std::optional<std::shared_ptr<ImageFrame>> output_frame =
+        CreateOutputFrame(cc);
+    if (output_frame.has_value()) {
+      kOutImage(cc).Send(Image(output_frame.value()));
+    }
+  }
+
+  return absl::OkStatus();
+}
+
+std::optional<std::shared_ptr<ImageFrame>>
+FlatColorImageCalculator::CreateOutputFrame(CalculatorContext* cc) {
   const auto& options = cc->Options<FlatColorImageCalculatorOptions>();

   int output_height = -1;
@@ -112,7 +145,7 @@ absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) {
     output_height = input_image.height();
     output_width = input_image.width();
   } else {
-    return absl::OkStatus();
+    return std::nullopt;
   }

   Color color;
@@ -121,7 +154,7 @@ absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) {
   } else if (!kInColor(cc).IsEmpty()) {
     color = kInColor(cc).Get();
   } else {
-    return absl::OkStatus();
+    return std::nullopt;
   }

   auto output_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
@@ -130,9 +163,7 @@ absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) {

   output_mat.setTo(cv::Scalar(color.r(), color.g(), color.b()));

-  kOutImage(cc).Send(Image(output_frame));
-
-  return absl::OkStatus();
+  return output_frame;
 }

 }  // namespace mediapipe
@@ -113,6 +113,35 @@ TEST(FlatColorImageCalculatorTest, SpecifyDimensionThroughOptions) {
   }
 }

+TEST(FlatColorImageCalculatorTest, ProducesOutputSidePacket) {
+  CalculatorRunner runner(R"pb(
+    calculator: "FlatColorImageCalculator"
+    output_side_packet: "IMAGE:out_packet"
+    options {
+      [mediapipe.FlatColorImageCalculatorOptions.ext] {
+        output_width: 1
+        output_height: 1
+        color: {
+          r: 100,
+          g: 200,
+          b: 255,
+        }
+      }
+    }
+  )pb");
+
+  MP_ASSERT_OK(runner.Run());
+
+  const auto& image = runner.OutputSidePackets().Tag(kImageTag).Get<Image>();
+  EXPECT_EQ(image.width(), 1);
+  EXPECT_EQ(image.height(), 1);
+  auto image_frame = image.GetImageFrameSharedPtr();
+  const uint8_t* pixel_data = image_frame->PixelData();
+  EXPECT_EQ(pixel_data[0], 100);
+  EXPECT_EQ(pixel_data[1], 200);
+  EXPECT_EQ(pixel_data[2], 255);
+}
+
 TEST(FlatColorImageCalculatorTest, FailureMissingDimension) {
   CalculatorRunner runner(R"pb(
     calculator: "FlatColorImageCalculator"
@@ -206,5 +235,56 @@ TEST(FlatColorImageCalculatorTest, FailureDuplicateColor) {
               HasSubstr("Either set COLOR input stream"));
 }

+TEST(FlatColorImageCalculatorTest, FailureDuplicateOutputs) {
+  CalculatorRunner runner(R"pb(
+    calculator: "FlatColorImageCalculator"
+    output_stream: "IMAGE:out_image"
+    output_side_packet: "IMAGE:out_packet"
+    options {
+      [mediapipe.FlatColorImageCalculatorOptions.ext] {
+        output_width: 1
+        output_height: 1
+        color: {
+          r: 100,
+          g: 200,
+          b: 255,
+        }
+      }
+    }
+  )pb");
+
+  ASSERT_THAT(
+      runner.Run().message(),
+      HasSubstr("Set IMAGE either as output stream, or as output side packet"));
+}
+
+TEST(FlatColorImageCalculatorTest, FailureSettingInputImageOnOutputSidePacket) {
+  CalculatorRunner runner(R"pb(
+    calculator: "FlatColorImageCalculator"
+    input_stream: "IMAGE:image"
+    output_side_packet: "IMAGE:out_packet"
+    options {
+      [mediapipe.FlatColorImageCalculatorOptions.ext] {
+        color: {
+          r: 100,
+          g: 200,
+          b: 255,
+        }
+      }
+    }
+  )pb");
+
+  auto image_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
+                                                  kImageWidth, kImageHeight);
+
+  for (int ts = 0; ts < 3; ++ts) {
+    runner.MutableInputs()->Tag(kImageTag).packets.push_back(
+        MakePacket<Image>(image_frame).At(Timestamp(ts)));
+  }
+  ASSERT_THAT(runner.Run().message(),
+              HasSubstr("Set size through options, when setting IMAGE as "
+                        "output side packet"));
+}
+
 }  // namespace
 }  // namespace mediapipe
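Taken together, the calculator changes above mean a graph with no input streams can emit the flat-color image once, as an output side packet set in Open(). A minimal C++ usage sketch, not part of this commit — the packet name is illustrative and the side-packet lookup assumes CalculatorGraph::GetOutputSidePacket resolves node-produced side packets by name:

#include "mediapipe/framework/calculator_graph.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/port/parse_text_proto.h"

// Sketch only: runs a graph whose single node fills the side packet in
// Open(), then reads the packet back after the run completes.
absl::StatusOr<mediapipe::Image> GetFlatColorImage() {
  auto config =
      mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
        node {
          calculator: "FlatColorImageCalculator"
          output_side_packet: "IMAGE:out_packet"
          options {
            [mediapipe.FlatColorImageCalculatorOptions.ext] {
              output_width: 1
              output_height: 1
              color { r: 100 g: 200 b: 255 }
            }
          }
        }
      )pb");

  mediapipe::CalculatorGraph graph;
  if (absl::Status s = graph.Initialize(config); !s.ok()) return s;
  // With no input streams, the run only invokes Open()/Close() on the node.
  if (absl::Status s = graph.Run(); !s.ok()) return s;

  auto packet = graph.GetOutputSidePacket("out_packet");
  if (!packet.ok()) return packet.status();
  return packet->Get<mediapipe::Image>();
}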
@@ -190,14 +190,16 @@ TEST(PaddingEffectGeneratorTest, ScaleToMultipleOfTwo) {
   double target_aspect_ratio = 0.5;
   int expect_width = 14;
   int expect_height = input_height;
-  auto test_frame = absl::make_unique<ImageFrame>(/*format=*/ImageFormat::SRGB,
-                                                  input_width, input_height);
+  ImageFrame test_frame(/*format=*/ImageFormat::SRGB, input_width,
+                        input_height);
+  cv::Mat mat = formats::MatView(&test_frame);
+  mat = cv::Scalar(0, 0, 0);

-  PaddingEffectGenerator generator(test_frame->Width(), test_frame->Height(),
+  PaddingEffectGenerator generator(test_frame.Width(), test_frame.Height(),
                                    target_aspect_ratio,
                                    /*scale_to_multiple_of_two=*/true);
   ImageFrame result_frame;
-  MP_ASSERT_OK(generator.Process(*test_frame, 0.3, 40, 0.0, &result_frame));
+  MP_ASSERT_OK(generator.Process(test_frame, 0.3, 40, 0.0, &result_frame));
   EXPECT_EQ(result_frame.Width(), expect_width);
   EXPECT_EQ(result_frame.Height(), expect_height);
 }
@@ -113,11 +113,11 @@ class Image {
 #endif  // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
 #endif  // !MEDIAPIPE_DISABLE_GPU

-  // Get a GPU view. Automatically uploads from CPU if needed.
-  const mediapipe::GpuBuffer GetGpuBuffer() const {
-#if !MEDIAPIPE_DISABLE_GPU
-    if (use_gpu_ == false) ConvertToGpu();
-#endif  // !MEDIAPIPE_DISABLE_GPU
+  // Provides access to the underlying GpuBuffer storage.
+  // Automatically uploads from CPU to GPU if needed and requested through the
+  // `upload_to_gpu` argument.
+  const mediapipe::GpuBuffer GetGpuBuffer(bool upload_to_gpu = true) const {
+    if (!use_gpu_ && upload_to_gpu) ConvertToGpu();
     return gpu_buffer_;
   }

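For reference, a short sketch of the two call patterns the new `upload_to_gpu` flag enables (illustrative only; the function name is hypothetical):

// Sketch of call sites enabled by the hunk above.
void InspectStorage(const mediapipe::Image& image) {
  // Default behavior is unchanged: a CPU-resident image is uploaded to GPU.
  auto uploaded = image.GetGpuBuffer();

  // New: fetch the GpuBuffer without forcing a CPU-to-GPU copy, which is
  // exactly what FrameBufferProcessor::Convert in the first hunk now does.
  auto cpu_backed = image.GetGpuBuffer(/*upload_to_gpu=*/false);
}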
mediapipe/framework/port/drishti_proto_alias_rules.bzl (new file, 31 lines)
@@ -0,0 +1,31 @@
+"""Rules implementation for mediapipe_proto_alias.bzl, do not load directly."""
+
+def _copy_header_impl(ctx):
+    source = ctx.attr.source.replace("//", "").replace(":", "/")
+    files = []
+    for dep in ctx.attr.deps:
+        for header in dep[CcInfo].compilation_context.direct_headers:
+            if (header.short_path == source):
+                files.append(header)
+    if len(files) != 1:
+        fail("Expected exactly 1 source, got ", str(files))
+    dest_file = ctx.actions.declare_file(ctx.attr.filename)
+
+    # Use expand_template() with no substitutions as a simple copier.
+    ctx.actions.expand_template(
+        template = files[0],
+        output = dest_file,
+        substitutions = {},
+    )
+    return [DefaultInfo(files = depset([dest_file]))]
+
+copy_header = rule(
+    implementation = _copy_header_impl,
+    attrs = {
+        "filename": attr.string(),
+        "source": attr.string(),
+        "deps": attr.label_list(providers = [CcInfo]),
+    },
+    output_to_genfiles = True,
+    outputs = {"out": "%{filename}"},
+)
@@ -791,6 +791,7 @@ cc_library(
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
        "@stblib//:stb_image",
        "@stblib//:stb_image_write",
     ],
@@ -26,6 +26,7 @@
 #include "absl/status/status.h"
 #include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
 #include "absl/strings/str_join.h"
 #include "absl/strings/substitute.h"
 #include "mediapipe/framework/calculator.pb.h"
@@ -311,6 +312,13 @@ std::unique_ptr<ImageFrame> LoadTestPng(absl::string_view path,
 // Returns the path to the output if successful.
 absl::StatusOr<std::string> SavePngTestOutput(
     const mediapipe::ImageFrame& image, absl::string_view prefix) {
+  absl::flat_hash_set<ImageFormat::Format> supported_formats = {
+      ImageFormat::GRAY8, ImageFormat::SRGB, ImageFormat::SRGBA,
+      ImageFormat::LAB8, ImageFormat::SBGRA};
+  if (!supported_formats.contains(image.Format())) {
+    return absl::CancelledError(
+        absl::StrFormat("Format %d can not be saved to PNG.", image.Format()));
+  }
   std::string now_string = absl::FormatTime(absl::Now());
   std::string output_relative_path =
       absl::StrCat(prefix, "_", now_string, ".png");
@@ -59,7 +59,9 @@ class ObjectDetectorModel(tf.keras.Model):
     self._num_classes = num_classes
     self._model = self._build_model()
     checkpoint_folder = self._model_spec.downloaded_files.get_path()
-    checkpoint_file = os.path.join(checkpoint_folder, 'ckpt-277200')
+    checkpoint_file = os.path.join(
+        checkpoint_folder, self._model_spec.checkpoint_name
+    )
     self.load_checkpoint(checkpoint_file)
     self._model.summary()
     self.loss_trackers = [
@@ -80,7 +82,10 @@ class ObjectDetectorModel(tf.keras.Model):
             num_scales=3, aspect_ratios=[0.5, 1.0, 2.0], anchor_size=3
         ),
         backbone=configs.backbones.Backbone(
-            type='mobilenet', mobilenet=configs.backbones.MobileNet()
+            type='mobilenet',
+            mobilenet=configs.backbones.MobileNet(
+                model_id=self._model_spec.model_id
+            ),
         ),
         decoder=configs.decoders.Decoder(
             type='fpn',
@@ -26,6 +26,12 @@ MOBILENET_V2_FILES = file_util.DownloadedFiles(
     is_folder=True,
 )

+MOBILENET_MULTI_AVG_FILES = file_util.DownloadedFiles(
+    'object_detector/mobilenetmultiavg',
+    'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv3.5_ssd_coco/mobilenetv3.5_ssd_i256_ckpt.tar.gz',
+    is_folder=True,
+)
+

 @dataclasses.dataclass
 class ModelSpec(object):
@@ -38,13 +44,25 @@ class ModelSpec(object):
   stddev_rgb = (127.5,)

   downloaded_files: file_util.DownloadedFiles
+  checkpoint_name: str
   input_image_shape: List[int]
+  model_id: str


 mobilenet_v2_spec = functools.partial(
     ModelSpec,
     downloaded_files=MOBILENET_V2_FILES,
+    checkpoint_name='ckpt-277200',
     input_image_shape=[256, 256, 3],
+    model_id='MobileNetV2',
+)
+
+mobilenet_multi_avg_spec = functools.partial(
+    ModelSpec,
+    downloaded_files=MOBILENET_MULTI_AVG_FILES,
+    checkpoint_name='ckpt-277200',
+    input_image_shape=[256, 256, 3],
+    model_id='MobileNetMultiAVG',
 )


@@ -53,6 +71,7 @@ class SupportedModels(enum.Enum):
   """Predefined object detector model specs supported by Model Maker."""

   MOBILENET_V2 = mobilenet_v2_spec
+  MOBILENET_MULTI_AVG = mobilenet_multi_avg_spec

   @classmethod
   def get(cls, spec: 'SupportedModels') -> 'ModelSpec':
@@ -93,3 +93,8 @@ mediapipe_proto_library(
         "//mediapipe/framework:calculator_proto",
     ],
 )
+
+mediapipe_proto_library(
+    name = "transformer_params_proto",
+    srcs = ["transformer_params.proto"],
+)
@@ -0,0 +1,46 @@
+/* Copyright 2023 The MediaPipe Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package mediapipe.tasks.components.processors.proto;
+
+option java_package = "com.google.mediapipe.tasks.components.processors.proto";
+option java_outer_classname = "TransformerParametersProto";
+
+// The parameters of transformer (https://arxiv.org/pdf/1706.03762.pdf)
+message TransformerParameters {
+  // Batch size of tensors.
+  int32 batch_size = 1;
+
+  // Maximum sequence length of the input/output tensor.
+  int32 max_seq_length = 2;
+
+  // Embedding dimension (or model dimension), `d_model` in the paper.
+  // `d_k` == `d_v` == `d_model`/`h`.
+  int32 embedding_dim = 3;
+
+  // Hidden dimension used in the feedforward layer, `d_ff` in the paper.
+  int32 hidden_dimension = 4;
+
+  // Head dimension, `d_k` or `d_v` in the paper.
+  int32 head_dimension = 5;
+
+  // Number of heads, `h` in the paper.
+  int32 num_heads = 6;
+
+  // Number of stacked transformers, `N` in the paper.
+  int32 num_stacks = 7;
+}
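The field comments follow the standard transformer bookkeeping from the cited paper: `d_model` is split evenly across `h` heads, so `head_dimension` should equal `embedding_dim / num_heads`. A hedged C++ sketch populating the new message — the include path is inferred from the BUILD target above, the setters follow standard generated-protobuf conventions, and the values are illustrative:

#include "mediapipe/tasks/cc/components/processors/proto/transformer_params.pb.h"

// Illustrative values for a small encoder stack; not taken from the commit.
mediapipe::tasks::components::processors::proto::TransformerParameters params;
params.set_batch_size(1);
params.set_max_seq_length(512);
params.set_embedding_dim(512);      // d_model
params.set_num_heads(8);            // h
params.set_head_dimension(64);      // d_k == d_v == d_model / h == 512 / 8
params.set_hidden_dimension(2048);  // d_ff
params.set_num_stacks(6);           // N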
@@ -242,7 +242,7 @@ class FaceDetectorGraph : public core::ModelTaskGraph {
     auto matrix = preprocessing.Out(kMatrixTag);
     auto image_size = preprocessing.Out(kImageSizeTag);

-    // Face detection model inferece.
+    // Face detection model inference.
     auto& inference = AddInference(
         model_resources, subgraph_options.base_options().acceleration(), graph);
     preprocessed_tensors >> inference.In(kTensorsTag);
@@ -199,7 +199,9 @@ void ConfigureTensorsToImageCalculator(
 //   STYLIZED_IMAGE - mediapipe::Image
 //     The face stylization output image.
 //   FACE_ALIGNMENT - mediapipe::Image
-//     The face alignment output image.
+//     The aligned face image that is fed to the face stylization model to
+//     perform stylization. Also useful for preparing face stylization training
+//     data.
 //   IMAGE - mediapipe::Image
 //     The input image that the face landmarker runs on and has the pixel data
 //     stored on the target storage (CPU vs GPU).
@@ -211,6 +213,7 @@ void ConfigureTensorsToImageCalculator(
 //   input_stream: "NORM_RECT:norm_rect"
 //   output_stream: "IMAGE:image_out"
 //   output_stream: "STYLIZED_IMAGE:stylized_image"
+//   output_stream: "FACE_ALIGNMENT:face_alignment_image"
 //   options {
 //     [mediapipe.tasks.vision.face_stylizer.proto.FaceStylizerGraphOptions.ext]
 //     {
@@ -248,7 +251,7 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
             ->mutable_face_landmarker_graph_options(),
         graph[Input<Image>(kImageTag)],
         graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
-    const ModelResources* face_stylizer_model_resources;
+    const ModelResources* face_stylizer_model_resources = nullptr;
     if (output_stylized) {
       ASSIGN_OR_RETURN(
           const auto* model_resources,
@@ -332,7 +335,7 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
     auto face_rect = face_to_rect.Out(kNormRectTag);

     std::optional<Source<Image>> face_alignment;
-    // Output face alignment only.
+    // Output aligned face only.
     // In this case, the face stylization model inference is not required.
     // However, to keep consistent with the inference preprocessing steps, the
     // ImageToTensorCalculator is still used to perform image rotation,
@@ -28,6 +28,7 @@ limitations under the License.
 #include "mediapipe/framework/formats/image.h"
 #include "mediapipe/framework/formats/rect.pb.h"
 #include "mediapipe/tasks/cc/common.h"
+#include "mediapipe/tasks/cc/components/containers/keypoint.h"
 #include "mediapipe/tasks/cc/core/base_options.h"
 #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
 #include "mediapipe/tasks/cc/vision/core/running_mode.h"
@@ -60,6 +61,8 @@ constexpr absl::string_view kNormRectTag{"NORM_RECT"};
 constexpr absl::string_view kSubgraphTypeName{
     "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"};

+using components::containers::NormalizedKeypoint;
+
 using ::mediapipe::CalculatorGraphConfig;
 using ::mediapipe::Image;
 using ::mediapipe::NormalizedRect;
@@ -115,7 +118,7 @@ absl::StatusOr<RenderData> ConvertRoiToRenderData(const RegionOfInterest& roi) {
     case RegionOfInterest::Format::kUnspecified:
       return absl::InvalidArgumentError(
           "RegionOfInterest format not specified");
-    case RegionOfInterest::Format::kKeyPoint:
+    case RegionOfInterest::Format::kKeyPoint: {
       RET_CHECK(roi.keypoint.has_value());
       auto* annotation = result.add_render_annotations();
       annotation->mutable_color()->set_r(255);
@@ -125,6 +128,19 @@ absl::StatusOr<RenderData> ConvertRoiToRenderData(const RegionOfInterest& roi) {
       point->set_y(roi.keypoint->y);
       return result;
     }
+    case RegionOfInterest::Format::kScribble: {
+      RET_CHECK(roi.scribble.has_value());
+      auto* annotation = result.add_render_annotations();
+      annotation->mutable_color()->set_r(255);
+      for (const NormalizedKeypoint& keypoint : *(roi.scribble)) {
+        auto* point = annotation->mutable_scribble()->add_point();
+        point->set_normalized(true);
+        point->set_x(keypoint.x);
+        point->set_y(keypoint.y);
+      }
+      return result;
+    }
   }
   return absl::UnimplementedError("Unrecognized format");
 }
@@ -53,6 +53,7 @@ struct RegionOfInterest {
   enum class Format {
     kUnspecified = 0,  // Format not specified.
     kKeyPoint = 1,     // Using keypoint to represent ROI.
+    kScribble = 2,     // Using scribble to represent ROI.
   };

   // Specifies the format used to specify the region-of-interest. Note that
@@ -61,8 +62,13 @@ struct RegionOfInterest {
   Format format = Format::kUnspecified;

   // Represents the ROI in keypoint format, this should be non-nullopt if
-  // `format` is `KEYPOINT`.
+  // `format` is `kKeyPoint`.
   std::optional<components::containers::NormalizedKeypoint> keypoint;
+
+  // Represents the ROI in scribble format, this should be non-nullopt if
+  // `format` is `kScribble`.
+  std::optional<std::vector<components::containers::NormalizedKeypoint>>
+      scribble;
 };

 // Performs interactive segmentation on images.
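A short sketch of how a caller could build the new scribble ROI, mirroring the keypoint flow already in this header — the namespace and include paths are assumed from context, and the point values are placeholders:

#include <vector>

#include "mediapipe/tasks/cc/components/containers/keypoint.h"
#include "mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h"

// Hypothetical helper: a scribble is just an ordered list of normalized
// keypoints tracing a short stroke over the object to segment.
mediapipe::tasks::vision::interactive_segmenter::RegionOfInterest
MakeScribbleRoi() {
  using ::mediapipe::tasks::components::containers::NormalizedKeypoint;
  using ::mediapipe::tasks::vision::interactive_segmenter::RegionOfInterest;

  RegionOfInterest roi;
  roi.format = RegionOfInterest::Format::kScribble;
  roi.scribble = std::vector<NormalizedKeypoint>{
      {0.44f, 0.70f}, {0.44f, 0.71f}, {0.44f, 0.72f}};
  return roi;
}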
@@ -18,9 +18,12 @@ limitations under the License.
 #include <memory>
 #include <string>
 #include <utility>
+#include <variant>
+#include <vector>

 #include "absl/flags/flag.h"
 #include "absl/status/status.h"
+#include "absl/status/statusor.h"
 #include "absl/strings/string_view.h"
 #include "mediapipe/framework/deps/file_path.h"
 #include "mediapipe/framework/formats/image.h"
@@ -179,22 +182,46 @@ TEST_F(CreateFromOptionsTest, FailsWithNeitherOutputSet) {
 struct InteractiveSegmenterTestParams {
   std::string test_name;
   RegionOfInterest::Format format;
-  NormalizedKeypoint roi;
+  std::variant<NormalizedKeypoint, std::vector<NormalizedKeypoint>> roi;
   absl::string_view golden_mask_file;
   float similarity_threshold;
 };

-using SucceedSegmentationWithRoi =
-    ::testing::TestWithParam<InteractiveSegmenterTestParams>;
+class SucceedSegmentationWithRoi
+    : public ::testing::TestWithParam<InteractiveSegmenterTestParams> {
+ public:
+  absl::StatusOr<RegionOfInterest> TestParamsToTaskOptions() {
+    const InteractiveSegmenterTestParams& params = GetParam();
+
+    RegionOfInterest interaction_roi;
+    interaction_roi.format = params.format;
+    switch (params.format) {
+      case (RegionOfInterest::Format::kKeyPoint): {
+        interaction_roi.keypoint = std::get<NormalizedKeypoint>(params.roi);
+        break;
+      }
+      case (RegionOfInterest::Format::kScribble): {
+        interaction_roi.scribble =
+            std::get<std::vector<NormalizedKeypoint>>(params.roi);
+        break;
+      }
+      default: {
+        return absl::InvalidArgumentError("Unknown ROI format");
+      }
+    }
+
+    return interaction_roi;
+  }
+};

 TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
+  MP_ASSERT_OK_AND_ASSIGN(RegionOfInterest interaction_roi,
+                          TestParamsToTaskOptions());
   const InteractiveSegmenterTestParams& params = GetParam();
+
   MP_ASSERT_OK_AND_ASSIGN(
       Image image,
       DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
-  RegionOfInterest interaction_roi;
-  interaction_roi.format = params.format;
-  interaction_roi.keypoint = params.roi;
   auto options = std::make_unique<InteractiveSegmenterOptions>();
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kPtmModel);
@@ -220,13 +247,13 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
 }

 TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
-  const auto& params = GetParam();
+  MP_ASSERT_OK_AND_ASSIGN(RegionOfInterest interaction_roi,
+                          TestParamsToTaskOptions());
+  const InteractiveSegmenterTestParams& params = GetParam();
+
   MP_ASSERT_OK_AND_ASSIGN(
       Image image,
       DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
-  RegionOfInterest interaction_roi;
-  interaction_roi.format = params.format;
-  interaction_roi.keypoint = params.roi;
   auto options = std::make_unique<InteractiveSegmenterOptions>();
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kPtmModel);
@@ -253,11 +280,23 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
 INSTANTIATE_TEST_SUITE_P(
     SucceedSegmentationWithRoiTest, SucceedSegmentationWithRoi,
     ::testing::ValuesIn<InteractiveSegmenterTestParams>(
-        {{"PointToDog1", RegionOfInterest::Format::kKeyPoint,
+        {// Keypoint input.
+         {"PointToDog1", RegionOfInterest::Format::kKeyPoint,
           NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsMaskDog1, 0.84f},
          {"PointToDog2", RegionOfInterest::Format::kKeyPoint,
           NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsMaskDog2,
-          kGoldenMaskSimilarity}}),
+          kGoldenMaskSimilarity},
+         // Scribble input.
+         {"ScribbleToDog1", RegionOfInterest::Format::kScribble,
+          std::vector{NormalizedKeypoint{0.44, 0.70},
+                      NormalizedKeypoint{0.44, 0.71},
+                      NormalizedKeypoint{0.44, 0.72}},
+          kCatsAndDogsMaskDog1, 0.84f},
+         {"ScribbleToDog2", RegionOfInterest::Format::kScribble,
+          std::vector{NormalizedKeypoint{0.66, 0.66},
+                      NormalizedKeypoint{0.66, 0.67},
+                      NormalizedKeypoint{0.66, 0.68}},
+          kCatsAndDogsMaskDog2, kGoldenMaskSimilarity}}),
     [](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>&
            info) { return info.param.test_name; });
@@ -108,9 +108,18 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
             ->mutable_model_asset(),
         is_copy);
   }
+  if (options->base_options().acceleration().has_gpu()) {
+    core::proto::Acceleration gpu_accel;
+    gpu_accel.mutable_gpu()->set_use_advanced_gpu_api(true);
+    pose_detector_graph_options->mutable_base_options()
+        ->mutable_acceleration()
+        ->CopyFrom(gpu_accel);
+
+  } else {
   pose_detector_graph_options->mutable_base_options()
       ->mutable_acceleration()
       ->CopyFrom(options->base_options().acceleration());
+  }
   pose_detector_graph_options->mutable_base_options()->set_use_stream_mode(
       options->base_options().use_stream_mode());
   auto* pose_landmarks_detector_graph_options =
@@ -28,7 +28,12 @@
   return self;
 }

-// TODO: Implement hash
+- (NSUInteger)hash {
+  NSUInteger nonNullPropertiesHash =
+      @(self.location.x).hash ^ @(self.location.y).hash ^ @(self.score).hash;
+
+  return self.label ? nonNullPropertiesHash ^ self.label.hash : nonNullPropertiesHash;
+}

 - (BOOL)isEqual:(nullable id)object {
   if (!object) {
@@ -180,6 +180,7 @@ android_library(
     srcs = [
         "poselandmarker/PoseLandmarker.java",
         "poselandmarker/PoseLandmarkerResult.java",
+        "poselandmarker/PoseLandmarksConnections.java",
     ],
     javacopts = [
         "-Xep:AndroidJdkLibsChecker:OFF",
@@ -212,6 +213,7 @@ android_library(
     srcs = [
         "handlandmarker/HandLandmark.java",
         "handlandmarker/HandLandmarker.java",
        "handlandmarker/HandLandmarkerResult.java",
+        "handlandmarker/HandLandmarksConnections.java",
     ],
     javacopts = [
         "-Xep:AndroidJdkLibsChecker:OFF",
@@ -77,11 +77,13 @@ public class BaseVisionTaskApi implements AutoCloseable {
     }
     Map<String, Packet> inputPackets = new HashMap<>();
     inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
+    if (!normRectStreamName.isEmpty()) {
       inputPackets.put(
           normRectStreamName,
           runner
               .getPacketCreator()
               .createProto(convertToNormalizedRect(imageProcessingOptions, image)));
+    }
     return runner.process(inputPackets);
   }

@@ -105,11 +107,13 @@ public class BaseVisionTaskApi implements AutoCloseable {
     }
     Map<String, Packet> inputPackets = new HashMap<>();
     inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
+    if (!normRectStreamName.isEmpty()) {
       inputPackets.put(
           normRectStreamName,
           runner
               .getPacketCreator()
               .createProto(convertToNormalizedRect(imageProcessingOptions, image)));
+    }
     return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
   }

@@ -133,11 +137,13 @@ public class BaseVisionTaskApi implements AutoCloseable {
     }
     Map<String, Packet> inputPackets = new HashMap<>();
     inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
+    if (!normRectStreamName.isEmpty()) {
       inputPackets.put(
           normRectStreamName,
           runner
               .getPacketCreator()
               .createProto(convertToNormalizedRect(imageProcessingOptions, image)));
+    }
     runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
   }

@@ -0,0 +1,105 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.mediapipe.tasks.vision.handlandmarker;
+
+import com.google.auto.value.AutoValue;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+/** Hand landmarks connection constants. */
+public final class HandLandmarksConnections {
+
+  /** Value class representing hand landmarks connection. */
+  @AutoValue
+  public abstract static class Connection {
+    static Connection create(int start, int end) {
+      return new AutoValue_HandLandmarksConnections_Connection(start, end);
+    }
+
+    public abstract int start();
+
+    public abstract int end();
+  }
+
+  @SuppressWarnings("ConstantCaseForConstants")
+  public static final Set<Connection> HAND_PALM_CONNECTIONS =
+      Collections.unmodifiableSet(
+          new HashSet<>(
+              Arrays.asList(
+                  Connection.create(0, 1),
+                  Connection.create(0, 5),
+                  Connection.create(9, 13),
+                  Connection.create(13, 17),
+                  Connection.create(5, 9),
+                  Connection.create(0, 17))));
+
+  @SuppressWarnings("ConstantCaseForConstants")
+  public static final Set<Connection> HAND_THUMB_CONNECTIONS =
+      Collections.unmodifiableSet(
+          new HashSet<>(
+              Arrays.asList(
+                  Connection.create(1, 2), Connection.create(2, 3), Connection.create(3, 4))));
+
+  @SuppressWarnings("ConstantCaseForConstants")
+  public static final Set<Connection> HAND_INDEX_FINGER_CONNECTIONS =
+      Collections.unmodifiableSet(
+          new HashSet<>(
+              Arrays.asList(
+                  Connection.create(5, 6), Connection.create(6, 7), Connection.create(7, 8))));
+
+  @SuppressWarnings("ConstantCaseForConstants")
+  public static final Set<Connection> HAND_MIDDLE_FINGER_CONNECTIONS =
+      Collections.unmodifiableSet(
+          new HashSet<>(
+              Arrays.asList(
+                  Connection.create(9, 10), Connection.create(10, 11), Connection.create(11, 12))));
+
+  @SuppressWarnings("ConstantCaseForConstants")
+  public static final Set<Connection> HAND_RING_FINGER_CONNECTIONS =
+      Collections.unmodifiableSet(
+          new HashSet<>(
+              Arrays.asList(
+                  Connection.create(13, 14),
+                  Connection.create(14, 15),
+                  Connection.create(15, 16))));
+
+  @SuppressWarnings("ConstantCaseForConstants")
+  public static final Set<Connection> HAND_PINKY_FINGER_CONNECTIONS =
+      Collections.unmodifiableSet(
+          new HashSet<>(
+              Arrays.asList(
+                  Connection.create(17, 18),
+                  Connection.create(18, 19),
+                  Connection.create(19, 20))));
+
+  @SuppressWarnings("ConstantCaseForConstants")
+  public static final Set<Connection> HAND_CONNECTIONS =
+      Collections.unmodifiableSet(
+          Stream.of(
+                  HAND_PALM_CONNECTIONS.stream(),
+                  HAND_THUMB_CONNECTIONS.stream(),
+                  HAND_INDEX_FINGER_CONNECTIONS.stream(),
+                  HAND_MIDDLE_FINGER_CONNECTIONS.stream(),
+                  HAND_RING_FINGER_CONNECTIONS.stream(),
+                  HAND_PINKY_FINGER_CONNECTIONS.stream())
+              .flatMap(i -> i)
+              .collect(Collectors.toSet()));
+
+  private HandLandmarksConnections() {}
+}
@@ -502,6 +502,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
   /** The Region-Of-Interest (ROI) to interact with. */
   public static class RegionOfInterest {
     private NormalizedKeypoint keypoint;
+    private List<NormalizedKeypoint> scribble;

     private RegionOfInterest() {}

@@ -514,6 +515,16 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
       roi.keypoint = keypoint;
       return roi;
     }
+
+    /**
+     * Creates a {@link RegionOfInterest} instance representing scribbles over the object that the
+     * user wants to segment.
+     */
+    public static RegionOfInterest create(List<NormalizedKeypoint> scribble) {
+      RegionOfInterest roi = new RegionOfInterest();
+      roi.scribble = scribble;
+      return roi;
+    }
   }

   /**
@@ -535,6 +546,18 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
                       .setX(roi.keypoint.x())
                       .setY(roi.keypoint.y())))
           .build();
+    } else if (roi.scribble != null) {
+      RenderAnnotation.Scribble.Builder scribbleBuilder = RenderAnnotation.Scribble.newBuilder();
+      for (NormalizedKeypoint p : roi.scribble) {
+        scribbleBuilder.addPoint(RenderAnnotation.Point.newBuilder().setX(p.x()).setY(p.y()));
+      }
+
+      return builder
+          .addRenderAnnotations(
+              RenderAnnotation.newBuilder()
+                  .setColor(Color.newBuilder().setR(255))
+                  .setScribble(scribbleBuilder))
+          .build();
     }

     throw new IllegalArgumentException(
@@ -0,0 +1,80 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.mediapipe.tasks.vision.poselandmarker;
+
+import com.google.auto.value.AutoValue;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Set;
+
+/** Pose landmarks connection constants. */
+public final class PoseLandmarksConnections {
+
+  /** Value class representing pose landmarks connection. */
+  @AutoValue
+  public abstract static class Connection {
+    static Connection create(int start, int end) {
+      return new AutoValue_PoseLandmarksConnections_Connection(start, end);
+    }
+
+    public abstract int start();
+
+    public abstract int end();
+  }
+
+  @SuppressWarnings("ConstantCaseForConstants")
+  public static final Set<Connection> POSE_LANDMARKS =
+      Collections.unmodifiableSet(
+          new HashSet<>(
+              Arrays.asList(
+                  Connection.create(0, 1),
+                  Connection.create(1, 2),
+                  Connection.create(2, 3),
+                  Connection.create(3, 7),
+                  Connection.create(0, 4),
+                  Connection.create(4, 5),
+                  Connection.create(5, 6),
+                  Connection.create(6, 8),
+                  Connection.create(9, 10),
+                  Connection.create(11, 12),
+                  Connection.create(11, 13),
+                  Connection.create(13, 15),
+                  Connection.create(15, 17),
+                  Connection.create(15, 19),
+                  Connection.create(15, 21),
+                  Connection.create(17, 19),
+                  Connection.create(12, 14),
+                  Connection.create(14, 16),
+                  Connection.create(16, 18),
+                  Connection.create(16, 20),
+                  Connection.create(16, 22),
+                  Connection.create(18, 20),
+                  Connection.create(11, 23),
+                  Connection.create(12, 24),
+                  Connection.create(23, 24),
+                  Connection.create(23, 25),
+                  Connection.create(24, 26),
+                  Connection.create(25, 27),
+                  Connection.create(26, 28),
+                  Connection.create(27, 29),
+                  Connection.create(28, 30),
+                  Connection.create(29, 31),
+                  Connection.create(30, 32),
+                  Connection.create(27, 31),
+                  Connection.create(28, 32))));
+
+  private PoseLandmarksConnections() {}
+}
@@ -27,6 +27,7 @@ import com.google.mediapipe.tasks.core.BaseOptions;
 import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenterResult;
 import com.google.mediapipe.tasks.vision.interactivesegmenter.InteractiveSegmenter.InteractiveSegmenterOptions;
 import java.io.InputStream;
+import java.util.ArrayList;
 import java.util.List;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -36,7 +37,8 @@ import org.junit.runners.Suite.SuiteClasses;
 /** Test for {@link InteractiveSegmenter}. */
 @RunWith(Suite.class)
 @SuiteClasses({
-  InteractiveSegmenterTest.General.class,
+  InteractiveSegmenterTest.KeypointRoi.class,
+  InteractiveSegmenterTest.ScribbleRoi.class,
 })
 public class InteractiveSegmenterTest {
   private static final String DEEPLAB_MODEL_FILE = "ptm_512_hdt_ptm_woid.tflite";
@@ -44,7 +46,7 @@ public class InteractiveSegmenterTest {
   private static final int MAGNIFICATION_FACTOR = 10;
 
   @RunWith(AndroidJUnit4.class)
-  public static final class General extends InteractiveSegmenterTest {
+  public static final class KeypointRoi extends InteractiveSegmenterTest {
     @Test
     public void segment_successWithCategoryMask() throws Exception {
       final String inputImageName = CATS_AND_DOGS_IMAGE;
@@ -86,6 +88,57 @@ public class InteractiveSegmenterTest {
     }
   }
 
+  @RunWith(AndroidJUnit4.class)
+  public static final class ScribbleRoi extends InteractiveSegmenterTest {
+    @Test
+    public void segment_successWithCategoryMask() throws Exception {
+      final String inputImageName = CATS_AND_DOGS_IMAGE;
+      ArrayList<NormalizedKeypoint> scribble = new ArrayList<>();
+      scribble.add(NormalizedKeypoint.create(0.25f, 0.9f));
+      scribble.add(NormalizedKeypoint.create(0.25f, 0.91f));
+      scribble.add(NormalizedKeypoint.create(0.25f, 0.92f));
+      final InteractiveSegmenter.RegionOfInterest roi =
+          InteractiveSegmenter.RegionOfInterest.create(scribble);
+      InteractiveSegmenterOptions options =
+          InteractiveSegmenterOptions.builder()
+              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
+              .setOutputConfidenceMasks(false)
+              .setOutputCategoryMask(true)
+              .build();
+      InteractiveSegmenter imageSegmenter =
+          InteractiveSegmenter.createFromOptions(
+              ApplicationProvider.getApplicationContext(), options);
+      MPImage image = getImageFromAsset(inputImageName);
+      ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi);
+      assertThat(actualResult.categoryMask().isPresent()).isTrue();
+    }
+
+    @Test
+    public void segment_successWithConfidenceMask() throws Exception {
+      final String inputImageName = CATS_AND_DOGS_IMAGE;
+      ArrayList<NormalizedKeypoint> scribble = new ArrayList<>();
+      scribble.add(NormalizedKeypoint.create(0.25f, 0.9f));
+      scribble.add(NormalizedKeypoint.create(0.25f, 0.91f));
+      scribble.add(NormalizedKeypoint.create(0.25f, 0.92f));
+      final InteractiveSegmenter.RegionOfInterest roi =
+          InteractiveSegmenter.RegionOfInterest.create(scribble);
+      InteractiveSegmenterOptions options =
+          InteractiveSegmenterOptions.builder()
+              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
+              .setOutputConfidenceMasks(true)
+              .setOutputCategoryMask(false)
+              .build();
+      InteractiveSegmenter imageSegmenter =
+          InteractiveSegmenter.createFromOptions(
+              ApplicationProvider.getApplicationContext(), options);
+      ImageSegmenterResult actualResult =
+          imageSegmenter.segment(getImageFromAsset(inputImageName), roi);
+      assertThat(actualResult.confidenceMasks().isPresent()).isTrue();
+      List<MPImage> confidenceMasks = actualResult.confidenceMasks().get();
+      assertThat(confidenceMasks.size()).isEqualTo(2);
+    }
+  }
+
   private static MPImage getImageFromAsset(String filePath) throws Exception {
     AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
     InputStream istr = assetManager.open(filePath);
@@ -39,8 +39,10 @@ _Landmark = landmark_module.Landmark
 class LandmarksDetectionResult:
   """Represents the landmarks detection result.
 
-  Attributes: landmarks : A list of `NormalizedLandmark` objects. categories : A
-    list of `Category` objects. world_landmarks : A list of `Landmark` objects.
+  Attributes:
+    landmarks: A list of `NormalizedLandmark` objects.
+    categories: A list of `Category` objects.
+    world_landmarks: A list of `Landmark` objects.
     rect: A `NormalizedRect` object.
   """
@@ -49,3 +49,18 @@ py_test(
         "//mediapipe/tasks/python/text:text_embedder",
     ],
 )
+
+py_test(
+    name = "language_detector_test",
+    srcs = ["language_detector_test.py"],
+    data = [
+        "//mediapipe/tasks/testdata/text:language_detector",
+    ],
+    deps = [
+        "//mediapipe/tasks/python/components/containers:category",
+        "//mediapipe/tasks/python/components/containers:classification_result",
+        "//mediapipe/tasks/python/core:base_options",
+        "//mediapipe/tasks/python/test:test_utils",
+        "//mediapipe/tasks/python/text:language_detector",
+    ],
+)
mediapipe/tasks/python/test/text/language_detector_test.py (new file, 228 lines)
@@ -0,0 +1,228 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for language detector."""

import enum
import os

from absl.testing import absltest
from absl.testing import parameterized

from mediapipe.tasks.python.components.containers import category
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.text import language_detector

LanguageDetectorResult = language_detector.LanguageDetectorResult
LanguageDetectorPrediction = (
    language_detector.LanguageDetectorResult.Detection
)
_BaseOptions = base_options_module.BaseOptions
_Category = category.Category
_Classifications = classification_result_module.Classifications
_LanguageDetector = language_detector.LanguageDetector
_LanguageDetectorOptions = language_detector.LanguageDetectorOptions

_LANGUAGE_DETECTOR_MODEL = "language_detector.tflite"
_TEST_DATA_DIR = "mediapipe/tasks/testdata/text"

_SCORE_THRESHOLD = 0.3
_EN_TEXT = "To be, or not to be, that is the question"
_EN_EXPECTED_RESULT = LanguageDetectorResult(
    [LanguageDetectorPrediction("en", 0.999856)]
)
_FR_TEXT = (
    "Il y a beaucoup de bouches qui parlent et fort peu de têtes qui pensent."
)
_FR_EXPECTED_RESULT = LanguageDetectorResult(
    [LanguageDetectorPrediction("fr", 0.999781)]
)
_RU_TEXT = "это какой-то английский язык"
_RU_EXPECTED_RESULT = LanguageDetectorResult(
    [LanguageDetectorPrediction("ru", 0.993362)]
)
_MIXED_TEXT = "分久必合合久必分"
_MIXED_EXPECTED_RESULT = LanguageDetectorResult([
    LanguageDetectorPrediction("zh", 0.505424),
    LanguageDetectorPrediction("ja", 0.481617),
])
_TOLERANCE = 1e-6


class ModelFileType(enum.Enum):
  FILE_CONTENT = 1
  FILE_NAME = 2


class LanguageDetectorTest(parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.model_path = test_utils.get_test_data_path(
        os.path.join(_TEST_DATA_DIR, _LANGUAGE_DETECTOR_MODEL)
    )

  def _expect_language_detector_result_correct(
      self,
      actual_result: LanguageDetectorResult,
      expect_result: LanguageDetectorResult,
  ):
    for i, prediction in enumerate(actual_result.detections):
      expected_prediction = expect_result.detections[i]
      self.assertEqual(
          prediction.language_code,
          expected_prediction.language_code,
      )
      self.assertAlmostEqual(
          prediction.probability,
          expected_prediction.probability,
          delta=_TOLERANCE,
      )

  def test_create_from_file_succeeds_with_valid_model_path(self):
    # Creates with default option and valid model file successfully.
    with _LanguageDetector.create_from_model_path(self.model_path) as detector:
      self.assertIsInstance(detector, _LanguageDetector)

  def test_create_from_options_succeeds_with_valid_model_path(self):
    # Creates with options containing model file successfully.
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _LanguageDetectorOptions(base_options=base_options)
    with _LanguageDetector.create_from_options(options) as detector:
      self.assertIsInstance(detector, _LanguageDetector)

  def test_create_from_options_fails_with_invalid_model_path(self):
    with self.assertRaisesRegex(
        RuntimeError, "Unable to open file at /path/to/invalid/model.tflite"
    ):
      base_options = _BaseOptions(
          model_asset_path="/path/to/invalid/model.tflite"
      )
      options = _LanguageDetectorOptions(base_options=base_options)
      _LanguageDetector.create_from_options(options)

  def test_create_from_options_succeeds_with_valid_model_content(self):
    # Creates with options containing model content successfully.
    with open(self.model_path, "rb") as f:
      base_options = _BaseOptions(model_asset_buffer=f.read())
      options = _LanguageDetectorOptions(base_options=base_options)
      detector = _LanguageDetector.create_from_options(options)
      self.assertIsInstance(detector, _LanguageDetector)

  @parameterized.parameters(
      (ModelFileType.FILE_NAME, _EN_TEXT, _EN_EXPECTED_RESULT),
      (ModelFileType.FILE_CONTENT, _EN_TEXT, _EN_EXPECTED_RESULT),
      (ModelFileType.FILE_NAME, _FR_TEXT, _FR_EXPECTED_RESULT),
      (ModelFileType.FILE_CONTENT, _FR_TEXT, _FR_EXPECTED_RESULT),
      (ModelFileType.FILE_NAME, _RU_TEXT, _RU_EXPECTED_RESULT),
      (ModelFileType.FILE_CONTENT, _RU_TEXT, _RU_EXPECTED_RESULT),
      (ModelFileType.FILE_NAME, _MIXED_TEXT, _MIXED_EXPECTED_RESULT),
      (ModelFileType.FILE_CONTENT, _MIXED_TEXT, _MIXED_EXPECTED_RESULT),
  )
  def test_detect(self, model_file_type, text, expected_result):
    # Creates detector.
    if model_file_type is ModelFileType.FILE_NAME:
      base_options = _BaseOptions(model_asset_path=self.model_path)
    elif model_file_type is ModelFileType.FILE_CONTENT:
      with open(self.model_path, "rb") as f:
        model_content = f.read()
      base_options = _BaseOptions(model_asset_buffer=model_content)
    else:
      # Should never happen
      raise ValueError("model_file_type is invalid.")

    options = _LanguageDetectorOptions(
        base_options=base_options, score_threshold=_SCORE_THRESHOLD
    )
    detector = _LanguageDetector.create_from_options(options)

    # Performs language detection on the input.
    text_result = detector.detect(text)
    # Comparing results.
    self._expect_language_detector_result_correct(text_result, expected_result)
    # Closes the detector explicitly when the detector is not used in
    # a context.
    detector.close()

  @parameterized.parameters(
      (ModelFileType.FILE_NAME, _EN_TEXT, _EN_EXPECTED_RESULT),
      (ModelFileType.FILE_NAME, _FR_TEXT, _FR_EXPECTED_RESULT),
      (ModelFileType.FILE_NAME, _RU_TEXT, _RU_EXPECTED_RESULT),
      (ModelFileType.FILE_CONTENT, _MIXED_TEXT, _MIXED_EXPECTED_RESULT),
  )
  def test_detect_in_context(self, model_file_type, text, expected_result):
    # Creates detector.
    if model_file_type is ModelFileType.FILE_NAME:
      base_options = _BaseOptions(model_asset_path=self.model_path)
    elif model_file_type is ModelFileType.FILE_CONTENT:
      with open(self.model_path, "rb") as f:
        model_content = f.read()
      base_options = _BaseOptions(model_asset_buffer=model_content)
    else:
      # Should never happen
      raise ValueError("model_file_type is invalid.")

    options = _LanguageDetectorOptions(
        base_options=base_options, score_threshold=_SCORE_THRESHOLD
    )
    with _LanguageDetector.create_from_options(options) as detector:
      # Performs language detection on the input.
      text_result = detector.detect(text)
      # Comparing results.
      self._expect_language_detector_result_correct(
          text_result, expected_result
      )

  def test_allowlist_option(self):
    # Creates detector.
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _LanguageDetectorOptions(
        base_options=base_options,
        score_threshold=_SCORE_THRESHOLD,
        category_allowlist=["ja"],
    )
    with _LanguageDetector.create_from_options(options) as detector:
      # Performs language detection on the input.
      text_result = detector.detect(_MIXED_TEXT)
      # Comparing results.
      expected_result = LanguageDetectorResult(
          [LanguageDetectorPrediction("ja", 0.481617)]
      )
      self._expect_language_detector_result_correct(
          text_result, expected_result
      )

  def test_denylist_option(self):
    # Creates detector.
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _LanguageDetectorOptions(
        base_options=base_options,
        score_threshold=_SCORE_THRESHOLD,
        category_denylist=["ja"],
    )
    with _LanguageDetector.create_from_options(options) as detector:
      # Performs language detection on the input.
      text_result = detector.detect(_MIXED_TEXT)
      # Comparing results.
      expected_result = LanguageDetectorResult(
          [LanguageDetectorPrediction("zh", 0.505424)]
      )
      self._expect_language_detector_result_correct(
          text_result, expected_result
      )


if __name__ == "__main__":
  absltest.main()
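For orientation, the new task boils down to a create-then-detect flow. Below is a minimal usage sketch, assuming a LanguageDetector .tflite bundle is available at an illustrative local path (the path itself is not part of this change):

# Minimal LanguageDetector usage sketch; the model path is hypothetical.
from mediapipe.tasks.python.text import language_detector

with language_detector.LanguageDetector.create_from_model_path(
    "language_detector.tflite"  # Illustrative path, not from this change.
) as detector:
  result = detector.detect("To be, or not to be, that is the question")
  for detection in result.detections:
    # Each detection pairs an i18n language code with its probability.
    print(detection.language_code, detection.probability)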
@@ -185,3 +185,20 @@ py_test(
         "@com_google_protobuf//:protobuf_python",
     ],
 )
+
+py_test(
+    name = "face_aligner_test",
+    srcs = ["face_aligner_test.py"],
+    data = [
+        "//mediapipe/tasks/testdata/vision:test_images",
+        "//mediapipe/tasks/testdata/vision:test_models",
+    ],
+    deps = [
+        "//mediapipe/python:_framework_bindings",
+        "//mediapipe/tasks/python/components/containers:rect",
+        "//mediapipe/tasks/python/core:base_options",
+        "//mediapipe/tasks/python/test:test_utils",
+        "//mediapipe/tasks/python/vision:face_aligner",
+        "//mediapipe/tasks/python/vision/core:image_processing_options",
+    ],
+)
mediapipe/tasks/python/test/vision/face_aligner_test.py (new file, 190 lines)
@@ -0,0 +1,190 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for face aligner."""

import enum
import os

from absl.testing import absltest
from absl.testing import parameterized

from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import face_aligner
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module

_BaseOptions = base_options_module.BaseOptions
_Rect = rect.Rect
_Image = image_module.Image
_FaceAligner = face_aligner.FaceAligner
_FaceAlignerOptions = face_aligner.FaceAlignerOptions
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions

_MODEL = 'face_landmarker_v2.task'
_LARGE_FACE_IMAGE = 'portrait.jpg'
_MODEL_IMAGE_SIZE = 256
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'


class ModelFileType(enum.Enum):
  FILE_CONTENT = 1
  FILE_NAME = 2


class FaceAlignerTest(parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.test_image = _Image.create_from_file(
        test_utils.get_test_data_path(
            os.path.join(_TEST_DATA_DIR, _LARGE_FACE_IMAGE)
        )
    )
    self.model_path = test_utils.get_test_data_path(
        os.path.join(_TEST_DATA_DIR, _MODEL)
    )

  def test_create_from_file_succeeds_with_valid_model_path(self):
    # Creates with default option and valid model file successfully.
    with _FaceAligner.create_from_model_path(self.model_path) as aligner:
      self.assertIsInstance(aligner, _FaceAligner)

  def test_create_from_options_succeeds_with_valid_model_path(self):
    # Creates with options containing model file successfully.
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _FaceAlignerOptions(base_options=base_options)
    with _FaceAligner.create_from_options(options) as aligner:
      self.assertIsInstance(aligner, _FaceAligner)

  def test_create_from_options_fails_with_invalid_model_path(self):
    with self.assertRaisesRegex(
        RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'
    ):
      base_options = _BaseOptions(
          model_asset_path='/path/to/invalid/model.tflite'
      )
      options = _FaceAlignerOptions(base_options=base_options)
      _FaceAligner.create_from_options(options)

  def test_create_from_options_succeeds_with_valid_model_content(self):
    # Creates with options containing model content successfully.
    with open(self.model_path, 'rb') as f:
      base_options = _BaseOptions(model_asset_buffer=f.read())
      options = _FaceAlignerOptions(base_options=base_options)
      aligner = _FaceAligner.create_from_options(options)
      self.assertIsInstance(aligner, _FaceAligner)

  @parameterized.parameters(
      (ModelFileType.FILE_NAME, _LARGE_FACE_IMAGE),
      (ModelFileType.FILE_CONTENT, _LARGE_FACE_IMAGE),
  )
  def test_align(self, model_file_type, image_file_name):
    # Load the test image.
    self.test_image = _Image.create_from_file(
        test_utils.get_test_data_path(
            os.path.join(_TEST_DATA_DIR, image_file_name)
        )
    )
    # Creates aligner.
    if model_file_type is ModelFileType.FILE_NAME:
      base_options = _BaseOptions(model_asset_path=self.model_path)
    elif model_file_type is ModelFileType.FILE_CONTENT:
      with open(self.model_path, 'rb') as f:
        model_content = f.read()
      base_options = _BaseOptions(model_asset_buffer=model_content)
    else:
      # Should never happen
      raise ValueError('model_file_type is invalid.')

    options = _FaceAlignerOptions(base_options=base_options)
    aligner = _FaceAligner.create_from_options(options)

    # Performs face alignment on the input.
    aligned_image = aligner.align(self.test_image)
    self.assertIsInstance(aligned_image, _Image)
    # Closes the aligner explicitly when the aligner is not used in
    # a context.
    aligner.close()

  @parameterized.parameters(
      (ModelFileType.FILE_NAME, _LARGE_FACE_IMAGE),
      (ModelFileType.FILE_CONTENT, _LARGE_FACE_IMAGE),
  )
  def test_align_in_context(self, model_file_type, image_file_name):
    # Load the test image.
    self.test_image = _Image.create_from_file(
        test_utils.get_test_data_path(
            os.path.join(_TEST_DATA_DIR, image_file_name)
        )
    )
    # Creates aligner.
    if model_file_type is ModelFileType.FILE_NAME:
      base_options = _BaseOptions(model_asset_path=self.model_path)
    elif model_file_type is ModelFileType.FILE_CONTENT:
      with open(self.model_path, 'rb') as f:
        model_content = f.read()
      base_options = _BaseOptions(model_asset_buffer=model_content)
    else:
      # Should never happen
      raise ValueError('model_file_type is invalid.')

    options = _FaceAlignerOptions(base_options=base_options)
    with _FaceAligner.create_from_options(options) as aligner:
      # Performs face alignment on the input.
      aligned_image = aligner.align(self.test_image)
      self.assertIsInstance(aligned_image, _Image)
      self.assertEqual(aligned_image.width, _MODEL_IMAGE_SIZE)
      self.assertEqual(aligned_image.height, _MODEL_IMAGE_SIZE)

  def test_align_succeeds_with_region_of_interest(self):
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _FaceAlignerOptions(base_options=base_options)
    with _FaceAligner.create_from_options(options) as aligner:
      # Load the test image.
      test_image = _Image.create_from_file(
          test_utils.get_test_data_path(
              os.path.join(_TEST_DATA_DIR, _LARGE_FACE_IMAGE)
          )
      )
      # Region-of-interest around the face.
      roi = _Rect(left=0.32, top=0.02, right=0.67, bottom=0.32)
      image_processing_options = _ImageProcessingOptions(roi)
      # Performs face alignment on the input.
      aligned_image = aligner.align(test_image, image_processing_options)
      self.assertIsInstance(aligned_image, _Image)
      self.assertEqual(aligned_image.width, _MODEL_IMAGE_SIZE)
      self.assertEqual(aligned_image.height, _MODEL_IMAGE_SIZE)

  def test_align_succeeds_with_no_face_detected(self):
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _FaceAlignerOptions(base_options=base_options)
    with _FaceAligner.create_from_options(options) as aligner:
      # Load the test image.
      test_image = _Image.create_from_file(
          test_utils.get_test_data_path(
              os.path.join(_TEST_DATA_DIR, _LARGE_FACE_IMAGE)
          )
      )
      # Region-of-interest that doesn't contain a human face.
      roi = _Rect(left=0.1, top=0.1, right=0.2, bottom=0.2)
      image_processing_options = _ImageProcessingOptions(roi)
      # Performs face alignment on the input.
      aligned_image = aligner.align(test_image, image_processing_options)
      self.assertIsNone(aligned_image)


if __name__ == '__main__':
  absltest.main()
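The new tests exercise the same call pattern an application would use. A minimal sketch follows, assuming the face_landmarker_v2.task bundle and a portrait image are available locally (both file names are illustrative, taken from the test data):

# Minimal FaceAligner usage sketch; file paths are hypothetical.
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.vision import face_aligner
from mediapipe.tasks.python.vision.core import image_processing_options

with face_aligner.FaceAligner.create_from_model_path(
    "face_landmarker_v2.task"  # Illustrative path.
) as aligner:
  image = image_module.Image.create_from_file("portrait.jpg")
  # Optionally restrict alignment to a normalized region of interest.
  roi = rect.Rect(left=0.32, top=0.02, right=0.67, bottom=0.32)
  aligned = aligner.align(
      image, image_processing_options.ImageProcessingOptions(roi)
  )
  if aligned is not None:
    print(aligned.width, aligned.height)  # Matches the model output size.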
@@ -57,3 +57,22 @@ py_library(
         "//mediapipe/tasks/python/text/core:base_text_task_api",
     ],
 )
+
+py_library(
+    name = "language_detector",
+    srcs = [
+        "language_detector.py",
+    ],
+    deps = [
+        "//mediapipe/python:packet_creator",
+        "//mediapipe/python:packet_getter",
+        "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
+        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2",
+        "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_py_pb2",
+        "//mediapipe/tasks/python/components/containers:classification_result",
+        "//mediapipe/tasks/python/core:base_options",
+        "//mediapipe/tasks/python/core:optional_dependencies",
+        "//mediapipe/tasks/python/core:task_info",
+        "//mediapipe/tasks/python/text/core:base_text_task_api",
+    ],
+)
@@ -14,9 +14,13 @@
 
 """MediaPipe Tasks Text API."""
 
+import mediapipe.tasks.python.text.language_detector
 import mediapipe.tasks.python.text.text_classifier
 import mediapipe.tasks.python.text.text_embedder
 
+LanguageDetector = language_detector.LanguageDetector
+LanguageDetectorOptions = language_detector.LanguageDetectorOptions
+LanguageDetectorResult = language_detector.LanguageDetectorResult
 TextClassifier = text_classifier.TextClassifier
 TextClassifierOptions = text_classifier.TextClassifierOptions
 TextClassifierResult = text_classifier.TextClassifierResult
@@ -26,5 +30,6 @@ TextEmbedderResult = text_embedder.TextEmbedderResult
 
 # Remove unnecessary modules to avoid duplication in API docs.
 del mediapipe
+del language_detector
 del text_classifier
 del text_embedder
mediapipe/tasks/python/text/language_detector.py (new file, 220 lines)
@@ -0,0 +1,220 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe language detector task."""

import dataclasses
from typing import List, Optional

from mediapipe.python import packet_creator
from mediapipe.python import packet_getter
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2
from mediapipe.tasks.cc.text.text_classifier.proto import text_classifier_graph_options_pb2
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.text.core import base_text_task_api

_ClassificationResult = classification_result_module.ClassificationResult
_BaseOptions = base_options_module.BaseOptions
_TextClassifierGraphOptionsProto = (
    text_classifier_graph_options_pb2.TextClassifierGraphOptions
)
_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions
_TaskInfo = task_info_module.TaskInfo

_CLASSIFICATIONS_STREAM_NAME = 'classifications_out'
_CLASSIFICATIONS_TAG = 'CLASSIFICATIONS'
_TEXT_IN_STREAM_NAME = 'text_in'
_TEXT_TAG = 'TEXT'
_TASK_GRAPH_NAME = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph'


@dataclasses.dataclass
class LanguageDetectorResult:

  @dataclasses.dataclass
  class Detection:
    """A language code and its probability."""

    # An i18n language / locale code, e.g. "en" for English, "uz" for Uzbek,
    # "ja-Latn" for Japanese (romaji).
    language_code: str
    probability: float

  detections: List[Detection]


def _extract_language_detector_result(
    classification_result: classification_result_module.ClassificationResult,
) -> LanguageDetectorResult:
  """Extracts a LanguageDetectorResult from a ClassificationResult."""
  if len(classification_result.classifications) != 1:
    raise ValueError(
        'The LanguageDetector TextClassifierGraph should have exactly one '
        'classification head.'
    )
  languages_and_scores = classification_result.classifications[0]
  language_detector_result = LanguageDetectorResult([])
  for category in languages_and_scores.categories:
    if category.category_name is None:
      raise ValueError(
          'LanguageDetector ClassificationResult has a missing language code.'
      )
    prediction = LanguageDetectorResult.Detection(
        category.category_name, category.score
    )
    language_detector_result.detections.append(prediction)
  return language_detector_result


@dataclasses.dataclass
class LanguageDetectorOptions:
  """Options for the language detector task.

  Attributes:
    base_options: Base options for the language detector task.
    display_names_locale: The locale to use for display names specified through
      the TFLite Model Metadata.
    max_results: The maximum number of top-scored classification results to
      return.
    score_threshold: Overrides the ones provided in the model metadata. Results
      below this value are rejected.
    category_allowlist: Allowlist of category names. If non-empty,
      classification results whose category name is not in this set will be
      filtered out. Duplicate or unknown category names are ignored. Mutually
      exclusive with `category_denylist`.
    category_denylist: Denylist of category names. If non-empty, classification
      results whose category name is in this set will be filtered out.
      Duplicate or unknown category names are ignored. Mutually exclusive with
      `category_allowlist`.
  """

  base_options: _BaseOptions
  display_names_locale: Optional[str] = None
  max_results: Optional[int] = None
  score_threshold: Optional[float] = None
  category_allowlist: Optional[List[str]] = None
  category_denylist: Optional[List[str]] = None

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _TextClassifierGraphOptionsProto:
    """Generates a TextClassifierGraphOptions protobuf object."""
    base_options_proto = self.base_options.to_pb2()
    classifier_options_proto = _ClassifierOptionsProto(
        score_threshold=self.score_threshold,
        category_allowlist=self.category_allowlist,
        category_denylist=self.category_denylist,
        display_names_locale=self.display_names_locale,
        max_results=self.max_results,
    )

    return _TextClassifierGraphOptionsProto(
        base_options=base_options_proto,
        classifier_options=classifier_options_proto,
    )


class LanguageDetector(base_text_task_api.BaseTextTaskApi):
  """Class that predicts the language of an input text.

  This API expects a TFLite model with TFLite Model Metadata that contains the
  mandatory (described below) input tensors, output tensor, and the language
  codes in an AssociatedFile.

  Input tensors:
    (kTfLiteString)
    - 1 input tensor that is scalar or has shape [1] containing the input
      string.
  Output tensor:
    (kTfLiteFloat32)
    - 1 output tensor of shape `[1 x N]` where `N` is the number of languages.
  """

  @classmethod
  def create_from_model_path(cls, model_path: str) -> 'LanguageDetector':
    """Creates a `LanguageDetector` object from a TensorFlow Lite model and the default `LanguageDetectorOptions`.

    Args:
      model_path: Path to the model.

    Returns:
      `LanguageDetector` object that's created from the model file and the
      default `LanguageDetectorOptions`.

    Raises:
      ValueError: If failed to create `LanguageDetector` object from the
        provided file, such as an invalid file path.
      RuntimeError: If other types of error occurred.
    """
    base_options = _BaseOptions(model_asset_path=model_path)
    options = LanguageDetectorOptions(base_options=base_options)
    return cls.create_from_options(options)

  @classmethod
  def create_from_options(
      cls, options: LanguageDetectorOptions
  ) -> 'LanguageDetector':
    """Creates the `LanguageDetector` object from language detector options.

    Args:
      options: Options for the language detector task.

    Returns:
      `LanguageDetector` object that's created from `options`.

    Raises:
      ValueError: If failed to create `LanguageDetector` object from
        `LanguageDetectorOptions` such as missing the model.
      RuntimeError: If other types of error occurred.
    """
    task_info = _TaskInfo(
        task_graph=_TASK_GRAPH_NAME,
        input_streams=[':'.join([_TEXT_TAG, _TEXT_IN_STREAM_NAME])],
        output_streams=[
            ':'.join([_CLASSIFICATIONS_TAG, _CLASSIFICATIONS_STREAM_NAME])
        ],
        task_options=options,
    )
    return cls(task_info.generate_graph_config())

  def detect(self, text: str) -> LanguageDetectorResult:
    """Predicts the language of the input `text`.

    Args:
      text: The input text.

    Returns:
      A `LanguageDetectorResult` object that contains a list of languages and
      scores.

    Raises:
      ValueError: If any of the input arguments is invalid.
      RuntimeError: If language detection failed to run.
    """
    output_packets = self._runner.process(
        {_TEXT_IN_STREAM_NAME: packet_creator.create_string(text)}
    )

    classification_result_proto = classifications_pb2.ClassificationResult()
    classification_result_proto.CopyFrom(
        packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME])
    )

    classification_result = _ClassificationResult.create_from_pb2(
        classification_result_proto
    )
    return _extract_language_detector_result(classification_result)
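Since the task reuses the classifier-style options above, filtering works the same way it does for text classification. Here is a sketch of a category allowlist, with the threshold and language code borrowed from the new test (the model path is illustrative):

# Sketch: constrain LanguageDetector output with classifier-style options.
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.text import language_detector

options = language_detector.LanguageDetectorOptions(
    base_options=base_options_module.BaseOptions(
        model_asset_path="language_detector.tflite"  # Hypothetical path.
    ),
    score_threshold=0.3,
    category_allowlist=["ja"],  # Only Japanese predictions are kept.
)
with language_detector.LanguageDetector.create_from_options(options) as det:
  print(det.detect("分久必合合久必分").detections)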
@@ -264,3 +264,22 @@ py_library(
         "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
     ],
 )
+
+py_library(
+    name = "face_aligner",
+    srcs = [
+        "face_aligner.py",
+    ],
+    deps = [
+        "//mediapipe/python:_framework_bindings",
+        "//mediapipe/python:packet_creator",
+        "//mediapipe/python:packet_getter",
+        "//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_py_pb2",
+        "//mediapipe/tasks/python/core:base_options",
+        "//mediapipe/tasks/python/core:optional_dependencies",
+        "//mediapipe/tasks/python/core:task_info",
+        "//mediapipe/tasks/python/vision/core:base_vision_task_api",
+        "//mediapipe/tasks/python/vision/core:image_processing_options",
+        "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
+    ],
+)
@@ -15,6 +15,7 @@
 """MediaPipe Tasks Vision API."""
 
 import mediapipe.tasks.python.vision.core
+import mediapipe.tasks.python.vision.face_aligner
 import mediapipe.tasks.python.vision.face_detector
 import mediapipe.tasks.python.vision.face_landmarker
 import mediapipe.tasks.python.vision.face_stylizer
@@ -25,7 +26,10 @@ import mediapipe.tasks.python.vision.image_embedder
 import mediapipe.tasks.python.vision.image_segmenter
 import mediapipe.tasks.python.vision.interactive_segmenter
 import mediapipe.tasks.python.vision.object_detector
+import mediapipe.tasks.python.vision.pose_landmarker
 
+FaceAligner = face_aligner.FaceAligner
+FaceAlignerOptions = face_aligner.FaceAlignerOptions
 FaceDetector = face_detector.FaceDetector
 FaceDetectorOptions = face_detector.FaceDetectorOptions
 FaceDetectorResult = face_detector.FaceDetectorResult
@@ -41,6 +45,7 @@ GestureRecognizerResult = gesture_recognizer.GestureRecognizerResult
 HandLandmarker = hand_landmarker.HandLandmarker
 HandLandmarkerOptions = hand_landmarker.HandLandmarkerOptions
 HandLandmarkerResult = hand_landmarker.HandLandmarkerResult
+HandLandmarksConnections = hand_landmarker.HandLandmarksConnections
 ImageClassifier = image_classifier.ImageClassifier
 ImageClassifierOptions = image_classifier.ImageClassifierOptions
 ImageClassifierResult = image_classifier.ImageClassifierResult
@@ -54,10 +59,16 @@ InteractiveSegmenterOptions = interactive_segmenter.InteractiveSegmenterOptions
 InteractiveSegmenterRegionOfInterest = interactive_segmenter.RegionOfInterest
 ObjectDetector = object_detector.ObjectDetector
 ObjectDetectorOptions = object_detector.ObjectDetectorOptions
+ObjectDetectorResult = object_detector.ObjectDetectorResult
+PoseLandmarker = pose_landmarker.PoseLandmarker
+PoseLandmarkerOptions = pose_landmarker.PoseLandmarkerOptions
+PoseLandmarkerResult = pose_landmarker.PoseLandmarkerResult
+PoseLandmarksConnections = pose_landmarker.PoseLandmarksConnections
 RunningMode = core.vision_task_running_mode.VisionTaskRunningMode
 
 # Remove unnecessary modules to avoid duplication in API docs.
 del core
+del face_aligner
 del face_detector
 del face_landmarker
 del face_stylizer
@@ -68,4 +79,5 @@ del image_embedder
 del image_segmenter
 del interactive_segmenter
 del object_detector
+del pose_landmarker
 del mediapipe
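With these aliases in place, all of the new vision symbols resolve from the package root. A quick import sanity check, under the assumption that this change is installed:

# Sketch: the names exported by this change are importable from the root.
from mediapipe.tasks.python import vision

print(vision.FaceAligner, vision.FaceAlignerOptions)
print(vision.PoseLandmarker, vision.PoseLandmarkerResult)
print(vision.HandLandmarksConnections, vision.PoseLandmarksConnections)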
mediapipe/tasks/python/vision/face_aligner.py (new file, 158 lines)
@@ -0,0 +1,158 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe face aligner task."""

import dataclasses
from typing import Optional

from mediapipe.python import packet_creator
from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.cc.vision.face_stylizer.proto import face_stylizer_graph_options_pb2
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module

_BaseOptions = base_options_module.BaseOptions
_FaceStylizerGraphOptionsProto = (
    face_stylizer_graph_options_pb2.FaceStylizerGraphOptions
)
_RunningMode = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo

_FACE_ALIGNMENT_IMAGE_NAME = 'face_alignment'
_FACE_ALIGNMENT_IMAGE_TAG = 'FACE_ALIGNMENT'
_NORM_RECT_STREAM_NAME = 'norm_rect_in'
_NORM_RECT_TAG = 'NORM_RECT'
_IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.face_stylizer.FaceStylizerGraph'


@dataclasses.dataclass
class FaceAlignerOptions:
  """Options for the face aligner task.

  Attributes:
    base_options: Base options for the face aligner task.
  """

  base_options: _BaseOptions

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _FaceStylizerGraphOptionsProto:
    """Generates a FaceStylizerGraphOptions protobuf object."""
    base_options_proto = self.base_options.to_pb2()
    base_options_proto.use_stream_mode = False
    return _FaceStylizerGraphOptionsProto(base_options=base_options_proto)


class FaceAligner(base_vision_task_api.BaseVisionTaskApi):
  """Class that performs face alignment on images."""

  @classmethod
  def create_from_model_path(cls, model_path: str) -> 'FaceAligner':
    """Creates a `FaceAligner` object from a face landmarker task bundle and the default `FaceAlignerOptions`.

    Note that the created `FaceAligner` instance is in image mode, for
    aligning one face on a single image input.

    Args:
      model_path: Path to the face landmarker task bundle.

    Returns:
      `FaceAligner` object that's created from the model file and the default
      `FaceAlignerOptions`.

    Raises:
      ValueError: If failed to create `FaceAligner` object from the provided
        file, such as an invalid file path.
      RuntimeError: If other types of error occurred.
    """
    base_options = _BaseOptions(model_asset_path=model_path)
    options = FaceAlignerOptions(base_options=base_options)
    return cls.create_from_options(options)

  @classmethod
  def create_from_options(cls, options: FaceAlignerOptions) -> 'FaceAligner':
    """Creates the `FaceAligner` object from face aligner options.

    Args:
      options: Options for the face aligner task.

    Returns:
      `FaceAligner` object that's created from `options`.

    Raises:
      ValueError: If failed to create `FaceAligner` object from
        `FaceAlignerOptions` such as missing the model.
      RuntimeError: If other types of error occurred.
    """
    task_info = _TaskInfo(
        task_graph=_TASK_GRAPH_NAME,
        input_streams=[
            ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
            ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
        ],
        output_streams=[
            ':'.join([_FACE_ALIGNMENT_IMAGE_TAG, _FACE_ALIGNMENT_IMAGE_NAME]),
            ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
        ],
        task_options=options,
    )
    return cls(
        task_info.generate_graph_config(enable_flow_limiting=False),
        _RunningMode.IMAGE,
        None,
    )

  def align(
      self,
      image: image_module.Image,
      image_processing_options: Optional[_ImageProcessingOptions] = None,
  ) -> image_module.Image:
    """Performs face alignment on the provided MediaPipe Image.

    Only use this method when the FaceAligner is created with the image
    running mode.

    Args:
      image: MediaPipe Image.
      image_processing_options: Options for image processing.

    Returns:
      The aligned face image. The aligned output image size is the same as the
      model output size. None if no face is detected on the input image.

    Raises:
      ValueError: If any of the input arguments is invalid.
      RuntimeError: If face alignment failed to run.
    """
    normalized_rect = self.convert_to_normalized_rect(
        image_processing_options, image
    )
    output_packets = self._process_image_data({
        _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
        _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
            normalized_rect.to_pb2()
        ),
    })
    if output_packets[_FACE_ALIGNMENT_IMAGE_NAME].is_empty():
      return None
    return packet_getter.get_image(output_packets[_FACE_ALIGNMENT_IMAGE_NAME])
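Because `align` returns None when no face lands inside the (optional) region of interest, callers should guard the result. A short sketch under the same illustrative file names as above:

# Sketch: align() returns None when no face is detected, so guard the result.
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.vision import face_aligner

with face_aligner.FaceAligner.create_from_model_path(
    "face_landmarker_v2.task"  # Hypothetical path.
) as aligner:
  aligned = aligner.align(image_module.Image.create_from_file("portrait.jpg"))
  if aligned is None:
    print("No face detected; nothing to align.")
  else:
    print("Aligned face is", aligned.width, "x", aligned.height)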
@@ -2939,7 +2939,7 @@ class FaceLandmarkerOptions:
   Attributes:
     base_options: Base options for the face landmarker task.
     running_mode: The running mode of the task. Default to the image mode.
-      HandLandmarker has three running modes: 1) The image mode for detecting
+      FaceLandmarker has three running modes: 1) The image mode for detecting
       face landmarks on single image inputs. 2) The video mode for detecting
       face landmarks on the decoded frames of a video. 3) The live stream mode
       for detecting face landmarks on the live stream of input data, such as
@@ -82,6 +82,65 @@ class HandLandmark(enum.IntEnum):
   PINKY_TIP = 20
 
 
+class HandLandmarksConnections:
+  """The connections between hand landmarks."""
+
+  @dataclasses.dataclass
+  class Connection:
+    """The connection class for hand landmarks."""
+
+    start: int
+    end: int
+
+  HAND_PALM_CONNECTIONS: List[Connection] = [
+      Connection(0, 1),
+      Connection(1, 5),
+      Connection(9, 13),
+      Connection(13, 17),
+      Connection(5, 9),
+      Connection(0, 17),
+  ]
+
+  HAND_THUMB_CONNECTIONS: List[Connection] = [
+      Connection(1, 2),
+      Connection(2, 3),
+      Connection(3, 4),
+  ]
+
+  HAND_INDEX_FINGER_CONNECTIONS: List[Connection] = [
+      Connection(5, 6),
+      Connection(6, 7),
+      Connection(7, 8),
+  ]
+
+  HAND_MIDDLE_FINGER_CONNECTIONS: List[Connection] = [
+      Connection(9, 10),
+      Connection(10, 11),
+      Connection(11, 12),
+  ]
+
+  HAND_RING_FINGER_CONNECTIONS: List[Connection] = [
+      Connection(13, 14),
+      Connection(14, 15),
+      Connection(15, 16),
+  ]
+
+  HAND_PINKY_FINGER_CONNECTIONS: List[Connection] = [
+      Connection(17, 18),
+      Connection(18, 19),
+      Connection(19, 20),
+  ]
+
+  HAND_CONNECTIONS: List[Connection] = (
+      HAND_PALM_CONNECTIONS +
+      HAND_THUMB_CONNECTIONS +
+      HAND_INDEX_FINGER_CONNECTIONS +
+      HAND_MIDDLE_FINGER_CONNECTIONS +
+      HAND_RING_FINGER_CONNECTIONS +
+      HAND_PINKY_FINGER_CONNECTIONS
+  )
+
+
 @dataclasses.dataclass
 class HandLandmarkerResult:
   """The hand landmarks result from HandLandmarker, where each vector element represents a single hand detected in the image.
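The Connection dataclass is a plain (start, end) pair of landmark indices, so these lists convert directly into edge tuples for any drawing routine. A sketch using the package-level alias added by this change:

# Sketch: hand connections as plain index pairs into the 21 hand landmarks.
from mediapipe.tasks.python import vision

edges = [
    (c.start, c.end)
    for c in vision.HandLandmarksConnections.HAND_CONNECTIONS
]
assert len(edges) == 21  # 6 palm edges + 3 edges per finger x 5 fingers.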
@ -88,7 +88,7 @@ class InteractiveSegmenterOptions:
|
||||||
|
|
||||||
@doc_controls.do_not_generate_docs
|
@doc_controls.do_not_generate_docs
|
||||||
def to_pb2(self) -> _ImageSegmenterGraphOptionsProto:
|
def to_pb2(self) -> _ImageSegmenterGraphOptionsProto:
|
||||||
"""Generates an InteractiveSegmenterOptions protobuf object."""
|
"""Generates an ImageSegmenterGraphOptions protobuf object."""
|
||||||
base_options_proto = self.base_options.to_pb2()
|
base_options_proto = self.base_options.to_pb2()
|
||||||
base_options_proto.use_stream_mode = False
|
base_options_proto.use_stream_mode = False
|
||||||
segmenter_options_proto = _SegmenterOptionsProto()
|
segmenter_options_proto = _SegmenterOptionsProto()
|
||||||
|
|
|
@@ -132,6 +132,55 @@ def _build_landmarker_result(
   return pose_landmarker_result
 
 
+class PoseLandmarksConnections:
+  """The connections between pose landmarks."""
+
+  @dataclasses.dataclass
+  class Connection:
+    """The connection class for pose landmarks."""
+
+    start: int
+    end: int
+
+  POSE_LANDMARKS: List[Connection] = [
+      Connection(0, 1),
+      Connection(1, 2),
+      Connection(2, 3),
+      Connection(3, 7),
+      Connection(0, 4),
+      Connection(4, 5),
+      Connection(5, 6),
+      Connection(6, 8),
+      Connection(9, 10),
+      Connection(11, 12),
+      Connection(11, 13),
+      Connection(13, 15),
+      Connection(15, 17),
+      Connection(15, 19),
+      Connection(15, 21),
+      Connection(17, 19),
+      Connection(12, 14),
+      Connection(14, 16),
+      Connection(16, 18),
+      Connection(16, 20),
+      Connection(16, 22),
+      Connection(18, 20),
+      Connection(11, 23),
+      Connection(12, 24),
+      Connection(23, 24),
+      Connection(23, 25),
+      Connection(24, 26),
+      Connection(25, 27),
+      Connection(26, 28),
+      Connection(27, 29),
+      Connection(28, 30),
+      Connection(29, 31),
+      Connection(30, 32),
+      Connection(27, 31),
+      Connection(28, 32)
+  ]
 
 
 @dataclasses.dataclass
 class PoseLandmarkerOptions:
   """Options for the pose landmarker task.
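Since a list this long is easy to typo, a quick sanity check that every connection references one of the pose model's 33 landmarks (indices 0 through 32); a minimal sketch, assuming `PoseLandmarksConnections` is importable as defined above:

# Hypothetical sketch: verify POSE_LANDMARKS only uses valid indices.
NUM_POSE_LANDMARKS = 33  # the pose model emits 33 landmarks, indexed 0..32

referenced = {
    index
    for connection in PoseLandmarksConnections.POSE_LANDMARKS
    for index in (connection.start, connection.end)
}
assert all(0 <= index < NUM_POSE_LANDMARKS for index in referenced)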
@@ -22,6 +22,7 @@
 #include "mediapipe/framework/port/logging.h"
 #include "mediapipe/framework/port/vector.h"
 #include "mediapipe/util/color.pb.h"
+#include "mediapipe/util/render_data.pb.h"
 
 namespace mediapipe {
 namespace {
@@ -112,6 +113,8 @@ void AnnotationRenderer::RenderDataOnImage(const RenderData& render_data) {
       DrawGradientLine(annotation);
     } else if (annotation.data_case() == RenderAnnotation::kArrow) {
      DrawArrow(annotation);
+    } else if (annotation.data_case() == RenderAnnotation::kScribble) {
+      DrawScribble(annotation);
     } else {
       LOG(FATAL) << "Unknown annotation type: " << annotation.data_case();
     }
@@ -442,7 +445,11 @@ void AnnotationRenderer::DrawArrow(const RenderAnnotation& annotation) {
 }
 
 void AnnotationRenderer::DrawPoint(const RenderAnnotation& annotation) {
-  const auto& point = annotation.point();
+  DrawPoint(annotation.point(), annotation);
+}
+
+void AnnotationRenderer::DrawPoint(const RenderAnnotation::Point& point,
+                                   const RenderAnnotation& annotation) {
   int x = -1;
   int y = -1;
   if (point.normalized()) {
@@ -460,6 +467,12 @@ void AnnotationRenderer::DrawPoint(const RenderAnnotation& annotation) {
   cv::circle(mat_image_, point_to_draw, thickness, color, -1);
 }
 
+void AnnotationRenderer::DrawScribble(const RenderAnnotation& annotation) {
+  for (const RenderAnnotation::Point& point : annotation.scribble().point()) {
+    DrawPoint(point, annotation);
+  }
+}
+
 void AnnotationRenderer::DrawLine(const RenderAnnotation& annotation) {
   int x_start = -1;
   int y_start = -1;
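Note the design choice here: DrawScribble reuses DrawPoint for every scribble point, so a scribble renders as a series of filled circles sharing the annotation's color and thickness rather than as a connected polyline. A rough Python rendition of the same logic, assuming normalized coordinates (all names are illustrative):

import cv2

# Hypothetical Python equivalent of DrawScribble: each point becomes a filled
# circle, mirroring cv::circle(mat_image_, point_to_draw, thickness, color, -1).
def draw_scribble(mat_image, points, color, thickness):
  height, width = mat_image.shape[:2]
  for point in points:  # each point carries normalized (x, y) in [0, 1]
    center = (int(point.x * width), int(point.y * height))
    cv2.circle(mat_image, center, thickness, color, -1)  # -1 = filled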
@@ -96,6 +96,11 @@ class AnnotationRenderer {
 
   // Draws a point on the image as described in the annotation.
   void DrawPoint(const RenderAnnotation& annotation);
+  void DrawPoint(const RenderAnnotation::Point& point,
+                 const RenderAnnotation& annotation);
+
+  // Draws scribbles on the image as described in the annotation.
+  void DrawScribble(const RenderAnnotation& annotation);
 
   // Draws a line segment on the image as described in the annotation.
   void DrawLine(const RenderAnnotation& annotation);
@@ -131,6 +131,10 @@ message RenderAnnotation {
     optional Color color2 = 7;
   }
 
+  message Scribble {
+    repeated Point point = 1;
+  }
+
   message Arrow {
     // The arrow head will be drawn at (x_end, y_end).
     optional double x_start = 1;
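A minimal sketch of populating the new message from Python, assuming the generated bindings are importable as `mediapipe.util.render_data_pb2` (the exact import path depends on the build):

from mediapipe.util import render_data_pb2

# Hypothetical sketch: build a RenderAnnotation carrying a scribble.
annotation = render_data_pb2.RenderAnnotation()
for x, y in [(0.10, 0.20), (0.12, 0.22), (0.15, 0.21)]:
  point = annotation.scribble.point.add()  # selects the scribble oneof case
  point.x = x
  point.y = y
  point.normalized = True  # interpret coordinates as fractions of image size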
@@ -192,6 +196,7 @@ message RenderAnnotation {
     RoundedRectangle rounded_rectangle = 9;
     FilledRoundedRectangle filled_rounded_rectangle = 10;
     GradientLine gradient_line = 14;
+    Scribble scribble = 15;
   }
 
   // Thickness for drawing the annotation.