Merge branch 'ios-normalized-keypoint-hash' into ios-async-calls-fixes
Commit: db5fd168b6
@@ -95,7 +95,8 @@ absl::Status FrameBufferProcessor::Convert(const mediapipe::Image& input,
                    static_cast<int>(range_max) == 255);
   }

-  auto input_frame = input.GetGpuBuffer().GetReadView<FrameBuffer>();
+  auto input_frame =
+      input.GetGpuBuffer(/*upload_to_gpu=*/false).GetReadView<FrameBuffer>();
   const auto& output_shape = output_tensor.shape();
   MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
   FrameBuffer::Dimension output_dimension{/*width=*/output_shape.dims[2],
@@ -1285,12 +1285,14 @@ cc_library(
    srcs = ["flat_color_image_calculator.cc"],
    deps = [
        ":flat_color_image_calculator_cc_proto",
        "//mediapipe/framework:calculator_contract",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:node",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:image_frame",
        "//mediapipe/framework/formats:image_frame_opencv",
        "//mediapipe/framework/port:opencv_core",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/util:color_cc_proto",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
@@ -15,14 +15,13 @@
#include <memory>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/calculators/util/flat_color_image_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_contract.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/util/color.pb.h"

namespace mediapipe {
@@ -32,6 +31,7 @@ namespace {
 using ::mediapipe::api2::Input;
 using ::mediapipe::api2::Node;
 using ::mediapipe::api2::Output;
+using ::mediapipe::api2::SideOutput;
 }  // namespace

 // A calculator for generating an image filled with a single color.
@@ -45,7 +45,8 @@ using ::mediapipe::api2::Output;
 //
 // Outputs:
 //   IMAGE (Image)
-//     Image filled with the requested color.
+//     Image filled with the requested color. Can be either an output_stream
+//     or an output_side_packet.
 //
 // Example useage:
 // node {
@@ -68,9 +69,10 @@ class FlatColorImageCalculator : public Node {
  public:
   static constexpr Input<Image>::Optional kInImage{"IMAGE"};
   static constexpr Input<Color>::Optional kInColor{"COLOR"};
-  static constexpr Output<Image> kOutImage{"IMAGE"};
+  static constexpr Output<Image>::Optional kOutImage{"IMAGE"};
+  static constexpr SideOutput<Image>::Optional kOutSideImage{"IMAGE"};

-  MEDIAPIPE_NODE_CONTRACT(kInImage, kInColor, kOutImage);
+  MEDIAPIPE_NODE_CONTRACT(kInImage, kInColor, kOutImage, kOutSideImage);

   static absl::Status UpdateContract(CalculatorContract* cc) {
     const auto& options = cc->Options<FlatColorImageCalculatorOptions>();
@@ -81,6 +83,13 @@ class FlatColorImageCalculator : public Node {
     RET_CHECK(kInColor(cc).IsConnected() ^ options.has_color())
         << "Either set COLOR input stream, or set through options";

+    RET_CHECK(kOutImage(cc).IsConnected() ^ kOutSideImage(cc).IsConnected())
+        << "Set IMAGE either as output stream, or as output side packet";
+
+    RET_CHECK(!kOutSideImage(cc).IsConnected() ||
+              (options.has_output_height() && options.has_output_width()))
+        << "Set size through options, when setting IMAGE as output side packet";
+
     return absl::OkStatus();
   }

@@ -88,6 +97,9 @@ class FlatColorImageCalculator : public Node {
   absl::Status Process(CalculatorContext* cc) override;

  private:
+  std::optional<std::shared_ptr<ImageFrame>> CreateOutputFrame(
+      CalculatorContext* cc);
+
   bool use_dimension_from_option_ = false;
   bool use_color_from_option_ = false;
 };
@@ -96,10 +108,31 @@ MEDIAPIPE_REGISTER_NODE(FlatColorImageCalculator);
 absl::Status FlatColorImageCalculator::Open(CalculatorContext* cc) {
   use_dimension_from_option_ = !kInImage(cc).IsConnected();
   use_color_from_option_ = !kInColor(cc).IsConnected();
+
+  if (!kOutImage(cc).IsConnected()) {
+    std::optional<std::shared_ptr<ImageFrame>> output_frame =
+        CreateOutputFrame(cc);
+    if (output_frame.has_value()) {
+      kOutSideImage(cc).Set(Image(output_frame.value()));
+    }
+  }
   return absl::OkStatus();
 }

 absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) {
+  if (kOutImage(cc).IsConnected()) {
+    std::optional<std::shared_ptr<ImageFrame>> output_frame =
+        CreateOutputFrame(cc);
+    if (output_frame.has_value()) {
+      kOutImage(cc).Send(Image(output_frame.value()));
+    }
+  }
+
+  return absl::OkStatus();
+}
+
+std::optional<std::shared_ptr<ImageFrame>>
+FlatColorImageCalculator::CreateOutputFrame(CalculatorContext* cc) {
   const auto& options = cc->Options<FlatColorImageCalculatorOptions>();

   int output_height = -1;
@@ -112,7 +145,7 @@ absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) {
     output_height = input_image.height();
     output_width = input_image.width();
   } else {
-    return absl::OkStatus();
+    return std::nullopt;
   }

   Color color;
@@ -121,7 +154,7 @@ absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) {
   } else if (!kInColor(cc).IsEmpty()) {
     color = kInColor(cc).Get();
   } else {
-    return absl::OkStatus();
+    return std::nullopt;
   }

   auto output_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
@@ -130,9 +163,7 @@ absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) {

   output_mat.setTo(cv::Scalar(color.r(), color.g(), color.b()));

-  kOutImage(cc).Send(Image(output_frame));
-
-  return absl::OkStatus();
+  return output_frame;
 }

 }  // namespace mediapipe
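With IMAGE now usable as an output side packet, a consumer reads the generated image after the graph run completes. A minimal sketch, assuming a CalculatorGraphConfig `config` whose FlatColorImageCalculator node declares output_side_packet: "IMAGE:out_packet" (the wrapper function and packet name here are illustrative, not part of the commit):

    #include "mediapipe/framework/calculator_framework.h"
    #include "mediapipe/framework/formats/image.h"

    // Runs `config` to completion and reads the side-packet Image that
    // FlatColorImageCalculator produced during Open().
    absl::Status ReadFlatColorSidePacket(
        const mediapipe::CalculatorGraphConfig& config) {
      mediapipe::CalculatorGraph graph;
      MP_RETURN_IF_ERROR(graph.Initialize(config));
      MP_RETURN_IF_ERROR(graph.StartRun({}));
      MP_RETURN_IF_ERROR(graph.WaitUntilDone());
      // Output side packets are complete once the run has finished.
      ASSIGN_OR_RETURN(mediapipe::Packet packet,
                       graph.GetOutputSidePacket("out_packet"));
      const auto& image = packet.Get<mediapipe::Image>();
      RET_CHECK_GT(image.width(), 0);
      return absl::OkStatus();
    }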
@@ -113,6 +113,35 @@ TEST(FlatColorImageCalculatorTest, SpecifyDimensionThroughOptions) {
   }
 }

+TEST(FlatColorImageCalculatorTest, ProducesOutputSidePacket) {
+  CalculatorRunner runner(R"pb(
+    calculator: "FlatColorImageCalculator"
+    output_side_packet: "IMAGE:out_packet"
+    options {
+      [mediapipe.FlatColorImageCalculatorOptions.ext] {
+        output_width: 1
+        output_height: 1
+        color: {
+          r: 100,
+          g: 200,
+          b: 255,
+        }
+      }
+    }
+  )pb");
+
+  MP_ASSERT_OK(runner.Run());
+
+  const auto& image = runner.OutputSidePackets().Tag(kImageTag).Get<Image>();
+  EXPECT_EQ(image.width(), 1);
+  EXPECT_EQ(image.height(), 1);
+  auto image_frame = image.GetImageFrameSharedPtr();
+  const uint8_t* pixel_data = image_frame->PixelData();
+  EXPECT_EQ(pixel_data[0], 100);
+  EXPECT_EQ(pixel_data[1], 200);
+  EXPECT_EQ(pixel_data[2], 255);
+}
+
 TEST(FlatColorImageCalculatorTest, FailureMissingDimension) {
   CalculatorRunner runner(R"pb(
     calculator: "FlatColorImageCalculator"
@@ -206,5 +235,56 @@ TEST(FlatColorImageCalculatorTest, FailureDuplicateColor) {
               HasSubstr("Either set COLOR input stream"));
 }

+TEST(FlatColorImageCalculatorTest, FailureDuplicateOutputs) {
+  CalculatorRunner runner(R"pb(
+    calculator: "FlatColorImageCalculator"
+    output_stream: "IMAGE:out_image"
+    output_side_packet: "IMAGE:out_packet"
+    options {
+      [mediapipe.FlatColorImageCalculatorOptions.ext] {
+        output_width: 1
+        output_height: 1
+        color: {
+          r: 100,
+          g: 200,
+          b: 255,
+        }
+      }
+    }
+  )pb");
+
+  ASSERT_THAT(
+      runner.Run().message(),
+      HasSubstr("Set IMAGE either as output stream, or as output side packet"));
+}
+
+TEST(FlatColorImageCalculatorTest, FailureSettingInputImageOnOutputSidePacket) {
+  CalculatorRunner runner(R"pb(
+    calculator: "FlatColorImageCalculator"
+    input_stream: "IMAGE:image"
+    output_side_packet: "IMAGE:out_packet"
+    options {
+      [mediapipe.FlatColorImageCalculatorOptions.ext] {
+        color: {
+          r: 100,
+          g: 200,
+          b: 255,
+        }
+      }
+    }
+  )pb");
+
+  auto image_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
+                                                  kImageWidth, kImageHeight);
+
+  for (int ts = 0; ts < 3; ++ts) {
+    runner.MutableInputs()->Tag(kImageTag).packets.push_back(
+        MakePacket<Image>(image_frame).At(Timestamp(ts)));
+  }
+  ASSERT_THAT(runner.Run().message(),
+              HasSubstr("Set size through options, when setting IMAGE as "
+                        "output side packet"));
+}
+
 }  // namespace
 }  // namespace mediapipe
@@ -190,14 +190,16 @@ TEST(PaddingEffectGeneratorTest, ScaleToMultipleOfTwo) {
   double target_aspect_ratio = 0.5;
   int expect_width = 14;
   int expect_height = input_height;
-  auto test_frame = absl::make_unique<ImageFrame>(/*format=*/ImageFormat::SRGB,
-                                                  input_width, input_height);
+  ImageFrame test_frame(/*format=*/ImageFormat::SRGB, input_width,
+                        input_height);
   cv::Mat mat = formats::MatView(&test_frame);
   mat = cv::Scalar(0, 0, 0);

-  PaddingEffectGenerator generator(test_frame->Width(), test_frame->Height(),
+  PaddingEffectGenerator generator(test_frame.Width(), test_frame.Height(),
                                    target_aspect_ratio,
                                    /*scale_to_multiple_of_two=*/true);
   ImageFrame result_frame;
-  MP_ASSERT_OK(generator.Process(*test_frame, 0.3, 40, 0.0, &result_frame));
+  MP_ASSERT_OK(generator.Process(test_frame, 0.3, 40, 0.0, &result_frame));
   EXPECT_EQ(result_frame.Width(), expect_width);
   EXPECT_EQ(result_frame.Height(), expect_height);
 }
@@ -113,11 +113,11 @@ class Image {
 #endif  // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
 #endif  // !MEDIAPIPE_DISABLE_GPU

-  // Get a GPU view. Automatically uploads from CPU if needed.
-  const mediapipe::GpuBuffer GetGpuBuffer() const {
-#if !MEDIAPIPE_DISABLE_GPU
-    if (use_gpu_ == false) ConvertToGpu();
-#endif  // !MEDIAPIPE_DISABLE_GPU
+  // Provides access to the underlying GpuBuffer storage.
+  // Automatically uploads from CPU to GPU if needed and requested through the
+  // `upload_to_gpu` argument.
+  const mediapipe::GpuBuffer GetGpuBuffer(bool upload_to_gpu = true) const {
+    if (!use_gpu_ && upload_to_gpu) ConvertToGpu();
     return gpu_buffer_;
   }
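The practical difference is whether reading the buffer may force a CPU-to-GPU conversion. A minimal sketch, assuming an existing `mediapipe::Image image` (the surrounding calls are illustrative):

    // Default argument: may convert CPU storage to GPU first, matching the
    // behavior of the old zero-argument GetGpuBuffer().
    mediapipe::GpuBuffer uploaded = image.GetGpuBuffer();

    // upload_to_gpu=false: returns the current storage without conversion.
    // This is what the FrameBufferProcessor hunk above relies on to obtain a
    // FrameBuffer read view of CPU-backed data.
    mediapipe::GpuBuffer raw = image.GetGpuBuffer(/*upload_to_gpu=*/false);
    auto frame_buffer_view = raw.GetReadView<mediapipe::FrameBuffer>();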
mediapipe/framework/port/drishti_proto_alias_rules.bzl (new file, 31 lines)
@@ -0,0 +1,31 @@
"""Rules implementation for mediapipe_proto_alias.bzl, do not load directly."""

def _copy_header_impl(ctx):
    source = ctx.attr.source.replace("//", "").replace(":", "/")
    files = []
    for dep in ctx.attr.deps:
        for header in dep[CcInfo].compilation_context.direct_headers:
            if (header.short_path == source):
                files.append(header)
    if len(files) != 1:
        fail("Expected exactly 1 source, got ", str(files))
    dest_file = ctx.actions.declare_file(ctx.attr.filename)

    # Use expand_template() with no substitutions as a simple copier.
    ctx.actions.expand_template(
        template = files[0],
        output = dest_file,
        substitutions = {},
    )
    return [DefaultInfo(files = depset([dest_file]))]

copy_header = rule(
    implementation = _copy_header_impl,
    attrs = {
        "filename": attr.string(),
        "source": attr.string(),
        "deps": attr.label_list(providers = [CcInfo]),
    },
    output_to_genfiles = True,
    outputs = {"out": "%{filename}"},
)
@@ -791,6 +791,7 @@ cc_library(
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
         "@stblib//:stb_image",
         "@stblib//:stb_image_write",
     ],
@@ -26,6 +26,7 @@
 #include "absl/status/status.h"
 #include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
 #include "absl/strings/str_join.h"
 #include "absl/strings/substitute.h"
 #include "mediapipe/framework/calculator.pb.h"
@@ -311,6 +312,13 @@ std::unique_ptr<ImageFrame> LoadTestPng(absl::string_view path,
 // Returns the path to the output if successful.
 absl::StatusOr<std::string> SavePngTestOutput(
     const mediapipe::ImageFrame& image, absl::string_view prefix) {
+  absl::flat_hash_set<ImageFormat::Format> supported_formats = {
+      ImageFormat::GRAY8, ImageFormat::SRGB, ImageFormat::SRGBA,
+      ImageFormat::LAB8, ImageFormat::SBGRA};
+  if (!supported_formats.contains(image.Format())) {
+    return absl::CancelledError(
+        absl::StrFormat("Format %d can not be saved to PNG.", image.Format()));
+  }
   std::string now_string = absl::FormatTime(absl::Now());
   std::string output_relative_path =
       absl::StrCat(prefix, "_", now_string, ".png");
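A hedged usage sketch of the updated helper (the wrapper function below is hypothetical; only SavePngTestOutput comes from the diff):

    // Returns an error for formats outside the allow-list instead of writing
    // a broken PNG; otherwise yields the timestamped output path.
    absl::Status SaveForDebugging(const mediapipe::ImageFrame& frame) {
      ASSIGN_OR_RETURN(std::string png_path,
                       SavePngTestOutput(frame, "my_test"));
      LOG(INFO) << "Wrote test PNG to " << png_path;
      return absl::OkStatus();
    }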
@@ -59,7 +59,9 @@ class ObjectDetectorModel(tf.keras.Model):
     self._num_classes = num_classes
     self._model = self._build_model()
     checkpoint_folder = self._model_spec.downloaded_files.get_path()
-    checkpoint_file = os.path.join(checkpoint_folder, 'ckpt-277200')
+    checkpoint_file = os.path.join(
+        checkpoint_folder, self._model_spec.checkpoint_name
+    )
     self.load_checkpoint(checkpoint_file)
     self._model.summary()
     self.loss_trackers = [
@@ -80,7 +82,10 @@ class ObjectDetectorModel(tf.keras.Model):
             num_scales=3, aspect_ratios=[0.5, 1.0, 2.0], anchor_size=3
         ),
         backbone=configs.backbones.Backbone(
-            type='mobilenet', mobilenet=configs.backbones.MobileNet()
+            type='mobilenet',
+            mobilenet=configs.backbones.MobileNet(
+                model_id=self._model_spec.model_id
+            ),
         ),
         decoder=configs.decoders.Decoder(
             type='fpn',
@@ -26,6 +26,12 @@ MOBILENET_V2_FILES = file_util.DownloadedFiles(
     is_folder=True,
 )

+MOBILENET_MULTI_AVG_FILES = file_util.DownloadedFiles(
+    'object_detector/mobilenetmultiavg',
+    'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv3.5_ssd_coco/mobilenetv3.5_ssd_i256_ckpt.tar.gz',
+    is_folder=True,
+)
+

 @dataclasses.dataclass
 class ModelSpec(object):
@@ -38,13 +44,25 @@ class ModelSpec(object):
  stddev_rgb = (127.5,)

  downloaded_files: file_util.DownloadedFiles
  checkpoint_name: str
  input_image_shape: List[int]
  model_id: str


mobilenet_v2_spec = functools.partial(
    ModelSpec,
    downloaded_files=MOBILENET_V2_FILES,
    checkpoint_name='ckpt-277200',
    input_image_shape=[256, 256, 3],
    model_id='MobileNetV2',
)

mobilenet_multi_avg_spec = functools.partial(
    ModelSpec,
    downloaded_files=MOBILENET_MULTI_AVG_FILES,
    checkpoint_name='ckpt-277200',
    input_image_shape=[256, 256, 3],
    model_id='MobileNetMultiAVG',
)
@@ -53,6 +71,7 @@ class SupportedModels(enum.Enum):
   """Predefined object detector model specs supported by Model Maker."""

   MOBILENET_V2 = mobilenet_v2_spec
+  MOBILENET_MULTI_AVG = mobilenet_multi_avg_spec

   @classmethod
   def get(cls, spec: 'SupportedModels') -> 'ModelSpec':
@@ -93,3 +93,8 @@ mediapipe_proto_library(
         "//mediapipe/framework:calculator_proto",
     ],
 )
+
+mediapipe_proto_library(
+    name = "transformer_params_proto",
+    srcs = ["transformer_params.proto"],
+)
@@ -0,0 +1,46 @@
/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package mediapipe.tasks.components.processors.proto;

option java_package = "com.google.mediapipe.tasks.components.processors.proto";
option java_outer_classname = "TransformerParametersProto";

// The parameters of transformer (https://arxiv.org/pdf/1706.03762.pdf)
message TransformerParameters {
  // Batch size of tensors.
  int32 batch_size = 1;

  // Maximum sequence length of the input/output tensor.
  int32 max_seq_length = 2;

  // Embedding dimension (or model dimension), `d_model` in the paper.
  // `d_k` == `d_v` == `d_model`/`h`.
  int32 embedding_dim = 3;

  // Hidden dimension used in the feedforward layer, `d_ff` in the paper.
  int32 hidden_dimension = 4;

  // Head dimension, `d_k` or `d_v` in the paper.
  int32 head_dimension = 5;

  // Number of heads, `h` in the paper.
  int32 num_heads = 6;

  // Number of stacked transformers, `N` in the paper.
  int32 num_stacks = 7;
}
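As a consistency check on these fields, the relation quoted in the comments ties them together:

    head_dimension = embedding_dim / num_heads        (d_k = d_v = d_model / h)

For example, the base configuration of the cited paper would map to embedding_dim = 512, num_heads = 8, head_dimension = 64, and hidden_dimension = 2048; those concrete numbers come from the paper, not from this proto.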
@@ -242,7 +242,7 @@ class FaceDetectorGraph : public core::ModelTaskGraph {
     auto matrix = preprocessing.Out(kMatrixTag);
     auto image_size = preprocessing.Out(kImageSizeTag);

-    // Face detection model inferece.
+    // Face detection model inference.
     auto& inference = AddInference(
         model_resources, subgraph_options.base_options().acceleration(), graph);
     preprocessed_tensors >> inference.In(kTensorsTag);
@@ -199,7 +199,9 @@ void ConfigureTensorsToImageCalculator(
 //   STYLIZED_IMAGE - mediapipe::Image
 //     The face stylization output image.
 //   FACE_ALIGNMENT - mediapipe::Image
-//     The face alignment output image.
+//     The aligned face image that is fed to the face stylization model to
+//     perform stylization. Also useful for preparing face stylization training
+//     data.
 //   IMAGE - mediapipe::Image
 //     The input image that the face landmarker runs on and has the pixel data
 //     stored on the target storage (CPU vs GPU).
@@ -211,6 +213,7 @@ void ConfigureTensorsToImageCalculator(
 //   input_stream: "NORM_RECT:norm_rect"
 //   output_stream: "IMAGE:image_out"
 //   output_stream: "STYLIZED_IMAGE:stylized_image"
+//   output_stream: "FACE_ALIGNMENT:face_alignment_image"
 //   options {
 //     [mediapipe.tasks.vision.face_stylizer.proto.FaceStylizerGraphOptions.ext]
 //     {
@@ -248,7 +251,7 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
             ->mutable_face_landmarker_graph_options(),
         graph[Input<Image>(kImageTag)],
         graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
-    const ModelResources* face_stylizer_model_resources;
+    const ModelResources* face_stylizer_model_resources = nullptr;
     if (output_stylized) {
       ASSIGN_OR_RETURN(
           const auto* model_resources,
@@ -332,7 +335,7 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
     auto face_rect = face_to_rect.Out(kNormRectTag);

     std::optional<Source<Image>> face_alignment;
-    // Output face alignment only.
+    // Output aligned face only.
     // In this case, the face stylization model inference is not required.
     // However, to keep consistent with the inference preprocessing steps, the
     // ImageToTensorCalculator is still used to perform image rotation,
@@ -28,6 +28,7 @@ limitations under the License.
 #include "mediapipe/framework/formats/image.h"
 #include "mediapipe/framework/formats/rect.pb.h"
 #include "mediapipe/tasks/cc/common.h"
+#include "mediapipe/tasks/cc/components/containers/keypoint.h"
 #include "mediapipe/tasks/cc/core/base_options.h"
 #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
 #include "mediapipe/tasks/cc/vision/core/running_mode.h"
@@ -60,6 +61,8 @@ constexpr absl::string_view kNormRectTag{"NORM_RECT"};
 constexpr absl::string_view kSubgraphTypeName{
     "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"};

+using components::containers::NormalizedKeypoint;
+
 using ::mediapipe::CalculatorGraphConfig;
 using ::mediapipe::Image;
 using ::mediapipe::NormalizedRect;
@@ -115,7 +118,7 @@ absl::StatusOr<RenderData> ConvertRoiToRenderData(const RegionOfInterest& roi) {
     case RegionOfInterest::Format::kUnspecified:
       return absl::InvalidArgumentError(
           "RegionOfInterest format not specified");
-    case RegionOfInterest::Format::kKeyPoint:
+    case RegionOfInterest::Format::kKeyPoint: {
       RET_CHECK(roi.keypoint.has_value());
       auto* annotation = result.add_render_annotations();
       annotation->mutable_color()->set_r(255);
@@ -125,6 +128,19 @@ absl::StatusOr<RenderData> ConvertRoiToRenderData(const RegionOfInterest& roi) {
       point->set_y(roi.keypoint->y);
       return result;
+    }
+    case RegionOfInterest::Format::kScribble: {
+      RET_CHECK(roi.scribble.has_value());
+      auto* annotation = result.add_render_annotations();
+      annotation->mutable_color()->set_r(255);
+      for (const NormalizedKeypoint& keypoint : *(roi.scribble)) {
+        auto* point = annotation->mutable_scribble()->add_point();
+        point->set_normalized(true);
+        point->set_x(keypoint.x);
+        point->set_y(keypoint.y);
+      }
+      return result;
+    }
   }
   return absl::UnimplementedError("Unrecognized format");
 }
@@ -53,6 +53,7 @@ struct RegionOfInterest {
   enum class Format {
     kUnspecified = 0,  // Format not specified.
     kKeyPoint = 1,     // Using keypoint to represent ROI.
+    kScribble = 2,     // Using scribble to represent ROI.
   };

   // Specifies the format used to specify the region-of-interest. Note that
@@ -61,8 +62,13 @@ struct RegionOfInterest {
   Format format = Format::kUnspecified;

   // Represents the ROI in keypoint format, this should be non-nullopt if
-  // `format` is `KEYPOINT`.
+  // `format` is `kKeyPoint`.
   std::optional<components::containers::NormalizedKeypoint> keypoint;
+
+  // Represents the ROI in scribble format, this should be non-nullopt if
+  // `format` is `kScribble`.
+  std::optional<std::vector<components::containers::NormalizedKeypoint>>
+      scribble;
 };

 // Performs interactive segmentation on images.
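A minimal construction sketch for the new scribble variant (point values borrowed from the tests further down; the snippet itself is not part of the header):

    RegionOfInterest roi;
    roi.format = RegionOfInterest::Format::kScribble;
    // Three nearby normalized points approximating a short stroke.
    roi.scribble = std::vector<components::containers::NormalizedKeypoint>{
        {0.44, 0.70}, {0.44, 0.71}, {0.44, 0.72}};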
@@ -18,9 +18,12 @@ limitations under the License.
 #include <memory>
 #include <string>
 #include <utility>
+#include <variant>
+#include <vector>

 #include "absl/flags/flag.h"
 #include "absl/status/status.h"
+#include "absl/status/statusor.h"
 #include "absl/strings/string_view.h"
 #include "mediapipe/framework/deps/file_path.h"
 #include "mediapipe/framework/formats/image.h"
@@ -179,22 +182,46 @@ TEST_F(CreateFromOptionsTest, FailsWithNeitherOutputSet) {
 struct InteractiveSegmenterTestParams {
   std::string test_name;
   RegionOfInterest::Format format;
-  NormalizedKeypoint roi;
+  std::variant<NormalizedKeypoint, std::vector<NormalizedKeypoint>> roi;
   absl::string_view golden_mask_file;
   float similarity_threshold;
 };

-using SucceedSegmentationWithRoi =
-    ::testing::TestWithParam<InteractiveSegmenterTestParams>;
+class SucceedSegmentationWithRoi
+    : public ::testing::TestWithParam<InteractiveSegmenterTestParams> {
+ public:
+  absl::StatusOr<RegionOfInterest> TestParamsToTaskOptions() {
+    const InteractiveSegmenterTestParams& params = GetParam();
+
+    RegionOfInterest interaction_roi;
+    interaction_roi.format = params.format;
+    switch (params.format) {
+      case (RegionOfInterest::Format::kKeyPoint): {
+        interaction_roi.keypoint = std::get<NormalizedKeypoint>(params.roi);
+        break;
+      }
+      case (RegionOfInterest::Format::kScribble): {
+        interaction_roi.scribble =
+            std::get<std::vector<NormalizedKeypoint>>(params.roi);
+        break;
+      }
+      default: {
+        return absl::InvalidArgumentError("Unknown ROI format");
+      }
+    }
+
+    return interaction_roi;
+  }
+};

 TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
+  MP_ASSERT_OK_AND_ASSIGN(RegionOfInterest interaction_roi,
+                          TestParamsToTaskOptions());
   const InteractiveSegmenterTestParams& params = GetParam();

   MP_ASSERT_OK_AND_ASSIGN(
       Image image,
       DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
-  RegionOfInterest interaction_roi;
-  interaction_roi.format = params.format;
-  interaction_roi.keypoint = params.roi;
   auto options = std::make_unique<InteractiveSegmenterOptions>();
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kPtmModel);
@@ -220,13 +247,13 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
 }

 TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
-  const auto& params = GetParam();
+  MP_ASSERT_OK_AND_ASSIGN(RegionOfInterest interaction_roi,
+                          TestParamsToTaskOptions());
+  const InteractiveSegmenterTestParams& params = GetParam();

   MP_ASSERT_OK_AND_ASSIGN(
       Image image,
       DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
-  RegionOfInterest interaction_roi;
-  interaction_roi.format = params.format;
-  interaction_roi.keypoint = params.roi;
   auto options = std::make_unique<InteractiveSegmenterOptions>();
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kPtmModel);
@@ -253,11 +280,23 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
 INSTANTIATE_TEST_SUITE_P(
     SucceedSegmentationWithRoiTest, SucceedSegmentationWithRoi,
     ::testing::ValuesIn<InteractiveSegmenterTestParams>(
-        {{"PointToDog1", RegionOfInterest::Format::kKeyPoint,
+        {// Keypoint input.
+         {"PointToDog1", RegionOfInterest::Format::kKeyPoint,
           NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsMaskDog1, 0.84f},
          {"PointToDog2", RegionOfInterest::Format::kKeyPoint,
           NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsMaskDog2,
-          kGoldenMaskSimilarity}}),
+          kGoldenMaskSimilarity},
+         // Scribble input.
+         {"ScribbleToDog1", RegionOfInterest::Format::kScribble,
+          std::vector{NormalizedKeypoint{0.44, 0.70},
+                      NormalizedKeypoint{0.44, 0.71},
+                      NormalizedKeypoint{0.44, 0.72}},
+          kCatsAndDogsMaskDog1, 0.84f},
+         {"ScribbleToDog2", RegionOfInterest::Format::kScribble,
+          std::vector{NormalizedKeypoint{0.66, 0.66},
+                      NormalizedKeypoint{0.66, 0.67},
+                      NormalizedKeypoint{0.66, 0.68}},
+          kCatsAndDogsMaskDog2, kGoldenMaskSimilarity}}),
     [](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>&
            info) { return info.param.test_name; });
@@ -108,9 +108,18 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
             ->mutable_model_asset(),
         is_copy);
   }
+  if (options->base_options().acceleration().has_gpu()) {
+    core::proto::Acceleration gpu_accel;
+    gpu_accel.mutable_gpu()->set_use_advanced_gpu_api(true);
+    pose_detector_graph_options->mutable_base_options()
+        ->mutable_acceleration()
+        ->CopyFrom(gpu_accel);
+  } else {
+    pose_detector_graph_options->mutable_base_options()
+        ->mutable_acceleration()
+        ->CopyFrom(options->base_options().acceleration());
+  }
   pose_detector_graph_options->mutable_base_options()->set_use_stream_mode(
       options->base_options().use_stream_mode());
   auto* pose_landmarks_detector_graph_options =
@@ -28,7 +28,12 @@
   return self;
 }

-// TODO: Implement hash
+- (NSUInteger)hash {
+  NSUInteger nonNullPropertiesHash =
+      @(self.location.x).hash ^ @(self.location.y).hash ^ @(self.score).hash;
+
+  return self.label ? nonNullPropertiesHash ^ self.label.hash : nonNullPropertiesHash;
+}

 - (BOOL)isEqual:(nullable id)object {
   if (!object) {
@@ -180,6 +180,7 @@ android_library(
     srcs = [
         "poselandmarker/PoseLandmarker.java",
         "poselandmarker/PoseLandmarkerResult.java",
+        "poselandmarker/PoseLandmarksConnections.java",
     ],
     javacopts = [
         "-Xep:AndroidJdkLibsChecker:OFF",
@@ -212,6 +213,7 @@ android_library(
         "handlandmarker/HandLandmark.java",
         "handlandmarker/HandLandmarker.java",
         "handlandmarker/HandLandmarkerResult.java",
+        "handlandmarker/HandLandmarksConnections.java",
     ],
     javacopts = [
         "-Xep:AndroidJdkLibsChecker:OFF",
@@ -77,11 +77,13 @@ public class BaseVisionTaskApi implements AutoCloseable {
     }
     Map<String, Packet> inputPackets = new HashMap<>();
     inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
+    if (!normRectStreamName.isEmpty()) {
       inputPackets.put(
           normRectStreamName,
           runner
               .getPacketCreator()
               .createProto(convertToNormalizedRect(imageProcessingOptions, image)));
+    }
     return runner.process(inputPackets);
   }

@@ -105,11 +107,13 @@ public class BaseVisionTaskApi implements AutoCloseable {
     }
     Map<String, Packet> inputPackets = new HashMap<>();
     inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
+    if (!normRectStreamName.isEmpty()) {
       inputPackets.put(
           normRectStreamName,
           runner
               .getPacketCreator()
               .createProto(convertToNormalizedRect(imageProcessingOptions, image)));
+    }
     return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
   }

@@ -133,11 +137,13 @@ public class BaseVisionTaskApi implements AutoCloseable {
     }
     Map<String, Packet> inputPackets = new HashMap<>();
     inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
+    if (!normRectStreamName.isEmpty()) {
       inputPackets.put(
           normRectStreamName,
           runner
               .getPacketCreator()
               .createProto(convertToNormalizedRect(imageProcessingOptions, image)));
+    }
     runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
   }

@@ -0,0 +1,105 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.vision.handlandmarker;

import com.google.auto.value.AutoValue;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/** Hand landmarks connection constants. */
public final class HandLandmarksConnections {

  /** Value class representing hand landmarks connection. */
  @AutoValue
  public abstract static class Connection {
    static Connection create(int start, int end) {
      return new AutoValue_HandLandmarksConnections_Connection(start, end);
    }

    public abstract int start();

    public abstract int end();
  }

  @SuppressWarnings("ConstantCaseForConstants")
  public static final Set<Connection> HAND_PALM_CONNECTIONS =
      Collections.unmodifiableSet(
          new HashSet<>(
              Arrays.asList(
                  Connection.create(0, 1),
                  Connection.create(0, 5),
                  Connection.create(9, 13),
                  Connection.create(13, 17),
                  Connection.create(5, 9),
                  Connection.create(0, 17))));

  @SuppressWarnings("ConstantCaseForConstants")
  public static final Set<Connection> HAND_THUMB_CONNECTIONS =
      Collections.unmodifiableSet(
          new HashSet<>(
              Arrays.asList(
                  Connection.create(1, 2), Connection.create(2, 3), Connection.create(3, 4))));

  @SuppressWarnings("ConstantCaseForConstants")
  public static final Set<Connection> HAND_INDEX_FINGER_CONNECTIONS =
      Collections.unmodifiableSet(
          new HashSet<>(
              Arrays.asList(
                  Connection.create(5, 6), Connection.create(6, 7), Connection.create(7, 8))));

  @SuppressWarnings("ConstantCaseForConstants")
  public static final Set<Connection> HAND_MIDDLE_FINGER_CONNECTIONS =
      Collections.unmodifiableSet(
          new HashSet<>(
              Arrays.asList(
                  Connection.create(9, 10), Connection.create(10, 11), Connection.create(11, 12))));

  @SuppressWarnings("ConstantCaseForConstants")
  public static final Set<Connection> HAND_RING_FINGER_CONNECTIONS =
      Collections.unmodifiableSet(
          new HashSet<>(
              Arrays.asList(
                  Connection.create(13, 14),
                  Connection.create(14, 15),
                  Connection.create(15, 16))));

  @SuppressWarnings("ConstantCaseForConstants")
  public static final Set<Connection> HAND_PINKY_FINGER_CONNECTIONS =
      Collections.unmodifiableSet(
          new HashSet<>(
              Arrays.asList(
                  Connection.create(17, 18),
                  Connection.create(18, 19),
                  Connection.create(19, 20))));

  @SuppressWarnings("ConstantCaseForConstants")
  public static final Set<Connection> HAND_CONNECTIONS =
      Collections.unmodifiableSet(
          Stream.of(
                  HAND_PALM_CONNECTIONS.stream(),
                  HAND_THUMB_CONNECTIONS.stream(),
                  HAND_INDEX_FINGER_CONNECTIONS.stream(),
                  HAND_MIDDLE_FINGER_CONNECTIONS.stream(),
                  HAND_RING_FINGER_CONNECTIONS.stream(),
                  HAND_PINKY_FINGER_CONNECTIONS.stream())
              .flatMap(i -> i)
              .collect(Collectors.toSet()));

  private HandLandmarksConnections() {}
}
@@ -502,6 +502,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
   /** The Region-Of-Interest (ROI) to interact with. */
   public static class RegionOfInterest {
     private NormalizedKeypoint keypoint;
+    private List<NormalizedKeypoint> scribble;

     private RegionOfInterest() {}

@@ -514,6 +515,16 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
       roi.keypoint = keypoint;
       return roi;
     }
+
+    /**
+     * Creates a {@link RegionOfInterest} instance representing scribbles over the object that the
+     * user wants to segment.
+     */
+    public static RegionOfInterest create(List<NormalizedKeypoint> scribble) {
+      RegionOfInterest roi = new RegionOfInterest();
+      roi.scribble = scribble;
+      return roi;
+    }
   }

   /**
@@ -535,6 +546,18 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
                       .setX(roi.keypoint.x())
                       .setY(roi.keypoint.y())))
           .build();
+    } else if (roi.scribble != null) {
+      RenderAnnotation.Scribble.Builder scribbleBuilder = RenderAnnotation.Scribble.newBuilder();
+      for (NormalizedKeypoint p : roi.scribble) {
+        scribbleBuilder.addPoint(RenderAnnotation.Point.newBuilder().setX(p.x()).setY(p.y()));
+      }
+
+      return builder
+          .addRenderAnnotations(
+              RenderAnnotation.newBuilder()
+                  .setColor(Color.newBuilder().setR(255))
+                  .setScribble(scribbleBuilder))
+          .build();
     }

     throw new IllegalArgumentException(
@@ -0,0 +1,80 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.vision.poselandmarker;

import com.google.auto.value.AutoValue;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;

/** Pose landmarks connection constants. */
public final class PoseLandmarksConnections {

  /** Value class representing pose landmarks connection. */
  @AutoValue
  public abstract static class Connection {
    static Connection create(int start, int end) {
      return new AutoValue_PoseLandmarksConnections_Connection(start, end);
    }

    public abstract int start();

    public abstract int end();
  }

  @SuppressWarnings("ConstantCaseForConstants")
  public static final Set<Connection> POSE_LANDMARKS =
      Collections.unmodifiableSet(
          new HashSet<>(
              Arrays.asList(
                  Connection.create(0, 1),
                  Connection.create(1, 2),
                  Connection.create(2, 3),
                  Connection.create(3, 7),
                  Connection.create(0, 4),
                  Connection.create(4, 5),
                  Connection.create(5, 6),
                  Connection.create(6, 8),
                  Connection.create(9, 10),
                  Connection.create(11, 12),
                  Connection.create(11, 13),
                  Connection.create(13, 15),
                  Connection.create(15, 17),
                  Connection.create(15, 19),
                  Connection.create(15, 21),
                  Connection.create(17, 19),
                  Connection.create(12, 14),
                  Connection.create(14, 16),
                  Connection.create(16, 18),
                  Connection.create(16, 20),
                  Connection.create(16, 22),
                  Connection.create(18, 20),
                  Connection.create(11, 23),
                  Connection.create(12, 24),
                  Connection.create(23, 24),
                  Connection.create(23, 25),
                  Connection.create(24, 26),
                  Connection.create(25, 27),
                  Connection.create(26, 28),
                  Connection.create(27, 29),
                  Connection.create(28, 30),
                  Connection.create(29, 31),
                  Connection.create(30, 32),
                  Connection.create(27, 31),
                  Connection.create(28, 32))));

  private PoseLandmarksConnections() {}
}
@@ -27,6 +27,7 @@ import com.google.mediapipe.tasks.core.BaseOptions;
 import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenterResult;
 import com.google.mediapipe.tasks.vision.interactivesegmenter.InteractiveSegmenter.InteractiveSegmenterOptions;
 import java.io.InputStream;
+import java.util.ArrayList;
 import java.util.List;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -36,7 +37,8 @@ import org.junit.runners.Suite.SuiteClasses;
 /** Test for {@link InteractiveSegmenter}. */
 @RunWith(Suite.class)
 @SuiteClasses({
-  InteractiveSegmenterTest.General.class,
+  InteractiveSegmenterTest.KeypointRoi.class,
+  InteractiveSegmenterTest.ScribbleRoi.class,
 })
 public class InteractiveSegmenterTest {
   private static final String DEEPLAB_MODEL_FILE = "ptm_512_hdt_ptm_woid.tflite";
@@ -44,7 +46,7 @@ public class InteractiveSegmenterTest {
   private static final int MAGNIFICATION_FACTOR = 10;

   @RunWith(AndroidJUnit4.class)
-  public static final class General extends InteractiveSegmenterTest {
+  public static final class KeypointRoi extends InteractiveSegmenterTest {
     @Test
     public void segment_successWithCategoryMask() throws Exception {
       final String inputImageName = CATS_AND_DOGS_IMAGE;
@@ -86,6 +88,57 @@ public class InteractiveSegmenterTest {
     }
   }

+  @RunWith(AndroidJUnit4.class)
+  public static final class ScribbleRoi extends InteractiveSegmenterTest {
+    @Test
+    public void segment_successWithCategoryMask() throws Exception {
+      final String inputImageName = CATS_AND_DOGS_IMAGE;
+      ArrayList<NormalizedKeypoint> scribble = new ArrayList<>();
+      scribble.add(NormalizedKeypoint.create(0.25f, 0.9f));
+      scribble.add(NormalizedKeypoint.create(0.25f, 0.91f));
+      scribble.add(NormalizedKeypoint.create(0.25f, 0.92f));
+      final InteractiveSegmenter.RegionOfInterest roi =
+          InteractiveSegmenter.RegionOfInterest.create(scribble);
+      InteractiveSegmenterOptions options =
+          InteractiveSegmenterOptions.builder()
+              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
+              .setOutputConfidenceMasks(false)
+              .setOutputCategoryMask(true)
+              .build();
+      InteractiveSegmenter imageSegmenter =
+          InteractiveSegmenter.createFromOptions(
+              ApplicationProvider.getApplicationContext(), options);
+      MPImage image = getImageFromAsset(inputImageName);
+      ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi);
+      assertThat(actualResult.categoryMask().isPresent()).isTrue();
+    }
+
+    @Test
+    public void segment_successWithConfidenceMask() throws Exception {
+      final String inputImageName = CATS_AND_DOGS_IMAGE;
+      ArrayList<NormalizedKeypoint> scribble = new ArrayList<>();
+      scribble.add(NormalizedKeypoint.create(0.25f, 0.9f));
+      scribble.add(NormalizedKeypoint.create(0.25f, 0.91f));
+      scribble.add(NormalizedKeypoint.create(0.25f, 0.92f));
+      final InteractiveSegmenter.RegionOfInterest roi =
+          InteractiveSegmenter.RegionOfInterest.create(scribble);
+      InteractiveSegmenterOptions options =
+          InteractiveSegmenterOptions.builder()
+              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
+              .setOutputConfidenceMasks(true)
+              .setOutputCategoryMask(false)
+              .build();
+      InteractiveSegmenter imageSegmenter =
+          InteractiveSegmenter.createFromOptions(
+              ApplicationProvider.getApplicationContext(), options);
+      ImageSegmenterResult actualResult =
+          imageSegmenter.segment(getImageFromAsset(inputImageName), roi);
+      assertThat(actualResult.confidenceMasks().isPresent()).isTrue();
+      List<MPImage> confidenceMasks = actualResult.confidenceMasks().get();
+      assertThat(confidenceMasks.size()).isEqualTo(2);
+    }
+  }
+
   private static MPImage getImageFromAsset(String filePath) throws Exception {
     AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
     InputStream istr = assetManager.open(filePath);
@@ -39,9 +39,11 @@ _Landmark = landmark_module.Landmark
 class LandmarksDetectionResult:
   """Represents the landmarks detection result.

-  Attributes: landmarks : A list of `NormalizedLandmark` objects. categories : A
-  list of `Category` objects. world_landmarks : A list of `Landmark` objects.
-  rect : A `NormalizedRect` object.
+  Attributes:
+    landmarks: A list of `NormalizedLandmark` objects.
+    categories: A list of `Category` objects.
+    world_landmarks: A list of `Landmark` objects.
+    rect: A `NormalizedRect` object.
   """

   landmarks: Optional[List[_NormalizedLandmark]]
@@ -49,3 +49,18 @@ py_test(
         "//mediapipe/tasks/python/text:text_embedder",
     ],
 )
+
+py_test(
+    name = "language_detector_test",
+    srcs = ["language_detector_test.py"],
+    data = [
+        "//mediapipe/tasks/testdata/text:language_detector",
+    ],
+    deps = [
+        "//mediapipe/tasks/python/components/containers:category",
+        "//mediapipe/tasks/python/components/containers:classification_result",
+        "//mediapipe/tasks/python/core:base_options",
+        "//mediapipe/tasks/python/test:test_utils",
+        "//mediapipe/tasks/python/text:language_detector",
+    ],
+)
mediapipe/tasks/python/test/text/language_detector_test.py (new file, 228 lines)
@@ -0,0 +1,228 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for language detector."""

import enum
import os

from absl.testing import absltest
from absl.testing import parameterized

from mediapipe.tasks.python.components.containers import category
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.text import language_detector

LanguageDetectorResult = language_detector.LanguageDetectorResult
LanguageDetectorPrediction = (
    language_detector.LanguageDetectorResult.Detection
)
_BaseOptions = base_options_module.BaseOptions
_Category = category.Category
_Classifications = classification_result_module.Classifications
_LanguageDetector = language_detector.LanguageDetector
_LanguageDetectorOptions = language_detector.LanguageDetectorOptions

_LANGUAGE_DETECTOR_MODEL = "language_detector.tflite"
_TEST_DATA_DIR = "mediapipe/tasks/testdata/text"

_SCORE_THRESHOLD = 0.3
_EN_TEXT = "To be, or not to be, that is the question"
_EN_EXPECTED_RESULT = LanguageDetectorResult(
    [LanguageDetectorPrediction("en", 0.999856)]
)
_FR_TEXT = (
    "Il y a beaucoup de bouches qui parlent et fort peu de têtes qui pensent."
)
_FR_EXPECTED_RESULT = LanguageDetectorResult(
    [LanguageDetectorPrediction("fr", 0.999781)]
)
_RU_TEXT = "это какой-то английский язык"
_RU_EXPECTED_RESULT = LanguageDetectorResult(
    [LanguageDetectorPrediction("ru", 0.993362)]
)
_MIXED_TEXT = "分久必合合久必分"
_MIXED_EXPECTED_RESULT = LanguageDetectorResult([
    LanguageDetectorPrediction("zh", 0.505424),
    LanguageDetectorPrediction("ja", 0.481617),
])
_TOLERANCE = 1e-6


class ModelFileType(enum.Enum):
  FILE_CONTENT = 1
  FILE_NAME = 2


class LanguageDetectorTest(parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.model_path = test_utils.get_test_data_path(
        os.path.join(_TEST_DATA_DIR, _LANGUAGE_DETECTOR_MODEL)
    )

  def _expect_language_detector_result_correct(
      self,
      actual_result: LanguageDetectorResult,
      expect_result: LanguageDetectorResult,
  ):
    for i, prediction in enumerate(actual_result.detections):
      expected_prediction = expect_result.detections[i]
      self.assertEqual(
          prediction.language_code,
          expected_prediction.language_code,
      )
      self.assertAlmostEqual(
          prediction.probability,
          expected_prediction.probability,
          delta=_TOLERANCE,
      )

  def test_create_from_file_succeeds_with_valid_model_path(self):
    # Creates with default option and valid model file successfully.
    with _LanguageDetector.create_from_model_path(self.model_path) as detector:
      self.assertIsInstance(detector, _LanguageDetector)

  def test_create_from_options_succeeds_with_valid_model_path(self):
    # Creates with options containing model file successfully.
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _LanguageDetectorOptions(base_options=base_options)
    with _LanguageDetector.create_from_options(options) as detector:
      self.assertIsInstance(detector, _LanguageDetector)

  def test_create_from_options_fails_with_invalid_model_path(self):
    with self.assertRaisesRegex(
        RuntimeError, "Unable to open file at /path/to/invalid/model.tflite"
    ):
      base_options = _BaseOptions(
          model_asset_path="/path/to/invalid/model.tflite"
      )
      options = _LanguageDetectorOptions(base_options=base_options)
      _LanguageDetector.create_from_options(options)

  def test_create_from_options_succeeds_with_valid_model_content(self):
    # Creates with options containing model content successfully.
    with open(self.model_path, "rb") as f:
      base_options = _BaseOptions(model_asset_buffer=f.read())
      options = _LanguageDetectorOptions(base_options=base_options)
      detector = _LanguageDetector.create_from_options(options)
      self.assertIsInstance(detector, _LanguageDetector)

  @parameterized.parameters(
      (ModelFileType.FILE_NAME, _EN_TEXT, _EN_EXPECTED_RESULT),
      (ModelFileType.FILE_CONTENT, _EN_TEXT, _EN_EXPECTED_RESULT),
      (ModelFileType.FILE_NAME, _FR_TEXT, _FR_EXPECTED_RESULT),
      (ModelFileType.FILE_CONTENT, _FR_TEXT, _FR_EXPECTED_RESULT),
      (ModelFileType.FILE_NAME, _RU_TEXT, _RU_EXPECTED_RESULT),
      (ModelFileType.FILE_CONTENT, _RU_TEXT, _RU_EXPECTED_RESULT),
      (ModelFileType.FILE_NAME, _MIXED_TEXT, _MIXED_EXPECTED_RESULT),
      (ModelFileType.FILE_CONTENT, _MIXED_TEXT, _MIXED_EXPECTED_RESULT),
  )
  def test_detect(self, model_file_type, text, expected_result):
    # Creates detector.
    if model_file_type is ModelFileType.FILE_NAME:
      base_options = _BaseOptions(model_asset_path=self.model_path)
    elif model_file_type is ModelFileType.FILE_CONTENT:
      with open(self.model_path, "rb") as f:
        model_content = f.read()
      base_options = _BaseOptions(model_asset_buffer=model_content)
    else:
      # Should never happen
      raise ValueError("model_file_type is invalid.")

    options = _LanguageDetectorOptions(
        base_options=base_options, score_threshold=_SCORE_THRESHOLD
    )
    detector = _LanguageDetector.create_from_options(options)

    # Performs language detection on the input.
    text_result = detector.detect(text)
    # Comparing results.
    self._expect_language_detector_result_correct(text_result, expected_result)
    # Closes the detector explicitly when the detector is not used in
    # a context.
    detector.close()

  @parameterized.parameters(
      (ModelFileType.FILE_NAME, _EN_TEXT, _EN_EXPECTED_RESULT),
      (ModelFileType.FILE_NAME, _FR_TEXT, _FR_EXPECTED_RESULT),
      (ModelFileType.FILE_NAME, _RU_TEXT, _RU_EXPECTED_RESULT),
      (ModelFileType.FILE_CONTENT, _MIXED_TEXT, _MIXED_EXPECTED_RESULT),
  )
  def test_detect_in_context(self, model_file_type, text, expected_result):
    # Creates detector.
    if model_file_type is ModelFileType.FILE_NAME:
      base_options = _BaseOptions(model_asset_path=self.model_path)
    elif model_file_type is ModelFileType.FILE_CONTENT:
      with open(self.model_path, "rb") as f:
        model_content = f.read()
      base_options = _BaseOptions(model_asset_buffer=model_content)
    else:
      # Should never happen
      raise ValueError("model_file_type is invalid.")

    options = _LanguageDetectorOptions(
        base_options=base_options, score_threshold=_SCORE_THRESHOLD
    )
    with _LanguageDetector.create_from_options(options) as detector:
      # Performs language detection on the input.
      text_result = detector.detect(text)
      # Comparing results.
      self._expect_language_detector_result_correct(
          text_result, expected_result
      )

  def test_allowlist_option(self):
    # Creates detector.
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _LanguageDetectorOptions(
        base_options=base_options,
        score_threshold=_SCORE_THRESHOLD,
        category_allowlist=["ja"],
    )
    with _LanguageDetector.create_from_options(options) as detector:
      # Performs language detection on the input.
      text_result = detector.detect(_MIXED_TEXT)
      # Comparing results.
      expected_result = LanguageDetectorResult(
          [LanguageDetectorPrediction("ja", 0.481617)]
      )
      self._expect_language_detector_result_correct(
          text_result, expected_result
      )

  def test_denylist_option(self):
    # Creates detector.
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _LanguageDetectorOptions(
        base_options=base_options,
        score_threshold=_SCORE_THRESHOLD,
        category_denylist=["ja"],
    )
    with _LanguageDetector.create_from_options(options) as detector:
      # Performs language detection on the input.
      text_result = detector.detect(_MIXED_TEXT)
      # Comparing results.
      expected_result = LanguageDetectorResult(
          [LanguageDetectorPrediction("zh", 0.505424)]
      )
      self._expect_language_detector_result_correct(
          text_result, expected_result
      )


if __name__ == "__main__":
  absltest.main()
@@ -185,3 +185,20 @@ py_test(
         "@com_google_protobuf//:protobuf_python",
     ],
 )
+
+py_test(
+    name = "face_aligner_test",
+    srcs = ["face_aligner_test.py"],
+    data = [
+        "//mediapipe/tasks/testdata/vision:test_images",
+        "//mediapipe/tasks/testdata/vision:test_models",
+    ],
+    deps = [
+        "//mediapipe/python:_framework_bindings",
+        "//mediapipe/tasks/python/components/containers:rect",
+        "//mediapipe/tasks/python/core:base_options",
+        "//mediapipe/tasks/python/test:test_utils",
+        "//mediapipe/tasks/python/vision:face_aligner",
+        "//mediapipe/tasks/python/vision/core:image_processing_options",
+    ],
+)
mediapipe/tasks/python/test/vision/face_aligner_test.py (new file, 190 lines)
@@ -0,0 +1,190 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for face aligner."""

import enum
import os

from absl.testing import absltest
from absl.testing import parameterized

from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import face_aligner
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module

_BaseOptions = base_options_module.BaseOptions
_Rect = rect.Rect
_Image = image_module.Image
_FaceAligner = face_aligner.FaceAligner
_FaceAlignerOptions = face_aligner.FaceAlignerOptions
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions

_MODEL = 'face_landmarker_v2.task'
_LARGE_FACE_IMAGE = 'portrait.jpg'
_MODEL_IMAGE_SIZE = 256
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'


class ModelFileType(enum.Enum):
  FILE_CONTENT = 1
  FILE_NAME = 2


class FaceAlignerTest(parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.test_image = _Image.create_from_file(
        test_utils.get_test_data_path(
            os.path.join(_TEST_DATA_DIR, _LARGE_FACE_IMAGE)
        )
    )
    self.model_path = test_utils.get_test_data_path(
        os.path.join(_TEST_DATA_DIR, _MODEL)
    )

  def test_create_from_file_succeeds_with_valid_model_path(self):
    # Creates with default option and valid model file successfully.
    with _FaceAligner.create_from_model_path(self.model_path) as aligner:
      self.assertIsInstance(aligner, _FaceAligner)

  def test_create_from_options_succeeds_with_valid_model_path(self):
    # Creates with options containing model file successfully.
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _FaceAlignerOptions(base_options=base_options)
    with _FaceAligner.create_from_options(options) as aligner:
      self.assertIsInstance(aligner, _FaceAligner)

  def test_create_from_options_fails_with_invalid_model_path(self):
    with self.assertRaisesRegex(
        RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'
    ):
      base_options = _BaseOptions(
          model_asset_path='/path/to/invalid/model.tflite'
      )
      options = _FaceAlignerOptions(base_options=base_options)
      _FaceAligner.create_from_options(options)

  def test_create_from_options_succeeds_with_valid_model_content(self):
    # Creates with options containing model content successfully.
    with open(self.model_path, 'rb') as f:
      base_options = _BaseOptions(model_asset_buffer=f.read())
      options = _FaceAlignerOptions(base_options=base_options)
      aligner = _FaceAligner.create_from_options(options)
      self.assertIsInstance(aligner, _FaceAligner)

  @parameterized.parameters(
      (ModelFileType.FILE_NAME, _LARGE_FACE_IMAGE),
      (ModelFileType.FILE_CONTENT, _LARGE_FACE_IMAGE),
  )
  def test_align(self, model_file_type, image_file_name):
    # Load the test image.
    self.test_image = _Image.create_from_file(
        test_utils.get_test_data_path(
            os.path.join(_TEST_DATA_DIR, image_file_name)
        )
    )
    # Creates aligner.
    if model_file_type is ModelFileType.FILE_NAME:
      base_options = _BaseOptions(model_asset_path=self.model_path)
    elif model_file_type is ModelFileType.FILE_CONTENT:
      with open(self.model_path, 'rb') as f:
        model_content = f.read()
      base_options = _BaseOptions(model_asset_buffer=model_content)
    else:
      # Should never happen
      raise ValueError('model_file_type is invalid.')

    options = _FaceAlignerOptions(base_options=base_options)
    aligner = _FaceAligner.create_from_options(options)

    # Performs face alignment on the input.
    aligned_image = aligner.align(self.test_image)
    self.assertIsInstance(aligned_image, _Image)
    # Closes the aligner explicitly when the aligner is not used in
    # a context.
    aligner.close()

  @parameterized.parameters(
      (ModelFileType.FILE_NAME, _LARGE_FACE_IMAGE),
      (ModelFileType.FILE_CONTENT, _LARGE_FACE_IMAGE),
  )
  def test_align_in_context(self, model_file_type, image_file_name):
    # Load the test image.
    self.test_image = _Image.create_from_file(
        test_utils.get_test_data_path(
            os.path.join(_TEST_DATA_DIR, image_file_name)
        )
    )
    # Creates aligner.
    if model_file_type is ModelFileType.FILE_NAME:
      base_options = _BaseOptions(model_asset_path=self.model_path)
    elif model_file_type is ModelFileType.FILE_CONTENT:
      with open(self.model_path, 'rb') as f:
        model_content = f.read()
      base_options = _BaseOptions(model_asset_buffer=model_content)
    else:
      # Should never happen
      raise ValueError('model_file_type is invalid.')

    options = _FaceAlignerOptions(base_options=base_options)
    with _FaceAligner.create_from_options(options) as aligner:
      # Performs face alignment on the input.
      aligned_image = aligner.align(self.test_image)
      self.assertIsInstance(aligned_image, _Image)
      self.assertEqual(aligned_image.width, _MODEL_IMAGE_SIZE)
      self.assertEqual(aligned_image.height, _MODEL_IMAGE_SIZE)

  def test_align_succeeds_with_region_of_interest(self):
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _FaceAlignerOptions(base_options=base_options)
    with _FaceAligner.create_from_options(options) as aligner:
      # Load the test image.
      test_image = _Image.create_from_file(
          test_utils.get_test_data_path(
              os.path.join(_TEST_DATA_DIR, _LARGE_FACE_IMAGE)
          )
      )
      # Region-of-interest around the face.
      roi = _Rect(left=0.32, top=0.02, right=0.67, bottom=0.32)
      image_processing_options = _ImageProcessingOptions(roi)
      # Performs face alignment on the input.
      aligned_image = aligner.align(test_image, image_processing_options)
      self.assertIsInstance(aligned_image, _Image)
      self.assertEqual(aligned_image.width, _MODEL_IMAGE_SIZE)
      self.assertEqual(aligned_image.height, _MODEL_IMAGE_SIZE)

  def test_align_succeeds_with_no_face_detected(self):
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _FaceAlignerOptions(base_options=base_options)
    with _FaceAligner.create_from_options(options) as aligner:
      # Load the test image.
      test_image = _Image.create_from_file(
          test_utils.get_test_data_path(
              os.path.join(_TEST_DATA_DIR, _LARGE_FACE_IMAGE)
          )
      )
      # Region-of-interest that doesn't contain a human face.
      roi = _Rect(left=0.1, top=0.1, right=0.2, bottom=0.2)
      image_processing_options = _ImageProcessingOptions(roi)
      # Performs face alignment on the input.
      aligned_image = aligner.align(test_image, image_processing_options)
      self.assertIsNone(aligned_image)


if __name__ == '__main__':
  absltest.main()
@@ -57,3 +57,22 @@ py_library(
        "//mediapipe/tasks/python/text/core:base_text_task_api",
    ],
)

py_library(
    name = "language_detector",
    srcs = [
        "language_detector.py",
    ],
    deps = [
        "//mediapipe/python:packet_creator",
        "//mediapipe/python:packet_getter",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2",
        "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_py_pb2",
        "//mediapipe/tasks/python/components/containers:classification_result",
        "//mediapipe/tasks/python/core:base_options",
        "//mediapipe/tasks/python/core:optional_dependencies",
        "//mediapipe/tasks/python/core:task_info",
        "//mediapipe/tasks/python/text/core:base_text_task_api",
    ],
)
@@ -14,9 +14,13 @@

"""MediaPipe Tasks Text API."""

import mediapipe.tasks.python.text.language_detector
import mediapipe.tasks.python.text.text_classifier
import mediapipe.tasks.python.text.text_embedder

LanguageDetector = language_detector.LanguageDetector
LanguageDetectorOptions = language_detector.LanguageDetectorOptions
LanguageDetectorResult = language_detector.LanguageDetectorResult
TextClassifier = text_classifier.TextClassifier
TextClassifierOptions = text_classifier.TextClassifierOptions
TextClassifierResult = text_classifier.TextClassifierResult

@@ -26,5 +30,6 @@ TextEmbedderResult = text_embedder.TextEmbedderResult

# Remove unnecessary modules to avoid duplication in API docs.
del mediapipe
del language_detector
del text_classifier
del text_embedder
mediapipe/tasks/python/text/language_detector.py (new file, 220 lines)
@@ -0,0 +1,220 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe language detector task."""

import dataclasses
from typing import List, Optional

from mediapipe.python import packet_creator
from mediapipe.python import packet_getter
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2
from mediapipe.tasks.cc.text.text_classifier.proto import text_classifier_graph_options_pb2
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.text.core import base_text_task_api

_ClassificationResult = classification_result_module.ClassificationResult
_BaseOptions = base_options_module.BaseOptions
_TextClassifierGraphOptionsProto = (
    text_classifier_graph_options_pb2.TextClassifierGraphOptions
)
_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions
_TaskInfo = task_info_module.TaskInfo

_CLASSIFICATIONS_STREAM_NAME = 'classifications_out'
_CLASSIFICATIONS_TAG = 'CLASSIFICATIONS'
_TEXT_IN_STREAM_NAME = 'text_in'
_TEXT_TAG = 'TEXT'
_TASK_GRAPH_NAME = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph'


@dataclasses.dataclass
class LanguageDetectorResult:

  @dataclasses.dataclass
  class Detection:
    """A language code and its probability."""

    # An i18n language / locale code, e.g. "en" for English, "uz" for Uzbek,
    # "ja-Latn" for Japanese (romaji).
    language_code: str
    probability: float

  detections: List[Detection]


def _extract_language_detector_result(
    classification_result: classification_result_module.ClassificationResult,
) -> LanguageDetectorResult:
  """Extracts a LanguageDetectorResult from a ClassificationResult."""
  if len(classification_result.classifications) != 1:
    raise ValueError(
        'The LanguageDetector TextClassifierGraph should have exactly one '
        'classification head.'
    )
  languages_and_scores = classification_result.classifications[0]
  language_detector_result = LanguageDetectorResult([])
  for category in languages_and_scores.categories:
    if category.category_name is None:
      raise ValueError(
          'LanguageDetector ClassificationResult has a missing language code.'
      )
    prediction = LanguageDetectorResult.Detection(
        category.category_name, category.score
    )
    language_detector_result.detections.append(prediction)
  return language_detector_result


@dataclasses.dataclass
class LanguageDetectorOptions:
  """Options for the language detector task.

  Attributes:
    base_options: Base options for the language detector task.
    display_names_locale: The locale to use for display names specified through
      the TFLite Model Metadata.
    max_results: The maximum number of top-scored classification results to
      return.
    score_threshold: Overrides the ones provided in the model metadata. Results
      below this value are rejected.
    category_allowlist: Allowlist of category names. If non-empty,
      classification results whose category name is not in this set will be
      filtered out. Duplicate or unknown category names are ignored. Mutually
      exclusive with `category_denylist`.
    category_denylist: Denylist of category names. If non-empty, classification
      results whose category name is in this set will be filtered out.
      Duplicate or unknown category names are ignored. Mutually exclusive with
      `category_allowlist`.
  """

  base_options: _BaseOptions
  display_names_locale: Optional[str] = None
  max_results: Optional[int] = None
  score_threshold: Optional[float] = None
  category_allowlist: Optional[List[str]] = None
  category_denylist: Optional[List[str]] = None

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _TextClassifierGraphOptionsProto:
    """Generates a TextClassifierGraphOptions protobuf object."""
    base_options_proto = self.base_options.to_pb2()
    classifier_options_proto = _ClassifierOptionsProto(
        score_threshold=self.score_threshold,
        category_allowlist=self.category_allowlist,
        category_denylist=self.category_denylist,
        display_names_locale=self.display_names_locale,
        max_results=self.max_results,
    )

    return _TextClassifierGraphOptionsProto(
        base_options=base_options_proto,
        classifier_options=classifier_options_proto,
    )


class LanguageDetector(base_text_task_api.BaseTextTaskApi):
  """Class that predicts the language of an input text.

  This API expects a TFLite model with TFLite Model Metadata that contains the
  mandatory (described below) input tensors, output tensor, and the language
  codes in an AssociatedFile.

  Input tensors:
    (kTfLiteString)
    - 1 input tensor that is scalar or has shape [1] containing the input
      string.
  Output tensor:
    (kTfLiteFloat32)
    - 1 output tensor of shape `[1 x N]` where `N` is the number of languages.
  """

  @classmethod
  def create_from_model_path(cls, model_path: str) -> 'LanguageDetector':
    """Creates a `LanguageDetector` object from a TensorFlow Lite model and the default `LanguageDetectorOptions`.

    Args:
      model_path: Path to the model.

    Returns:
      `LanguageDetector` object that's created from the model file and the
      default `LanguageDetectorOptions`.

    Raises:
      ValueError: If failed to create `LanguageDetector` object from the
        provided file such as invalid file path.
      RuntimeError: If other types of error occurred.
    """
    base_options = _BaseOptions(model_asset_path=model_path)
    options = LanguageDetectorOptions(base_options=base_options)
    return cls.create_from_options(options)

  @classmethod
  def create_from_options(
      cls, options: LanguageDetectorOptions
  ) -> 'LanguageDetector':
    """Creates the `LanguageDetector` object from language detector options.

    Args:
      options: Options for the language detector task.

    Returns:
      `LanguageDetector` object that's created from `options`.

    Raises:
      ValueError: If failed to create `LanguageDetector` object from
        `LanguageDetectorOptions` such as missing the model.
      RuntimeError: If other types of error occurred.
    """
    task_info = _TaskInfo(
        task_graph=_TASK_GRAPH_NAME,
        input_streams=[':'.join([_TEXT_TAG, _TEXT_IN_STREAM_NAME])],
        output_streams=[
            ':'.join([_CLASSIFICATIONS_TAG, _CLASSIFICATIONS_STREAM_NAME])
        ],
        task_options=options,
    )
    return cls(task_info.generate_graph_config())

  def detect(self, text: str) -> LanguageDetectorResult:
    """Predicts the language of the input `text`.

    Args:
      text: The input text.

    Returns:
      A `LanguageDetectorResult` object that contains a list of languages and
      scores.

    Raises:
      ValueError: If any of the input arguments is invalid.
      RuntimeError: If language detection failed to run.
    """
    output_packets = self._runner.process(
        {_TEXT_IN_STREAM_NAME: packet_creator.create_string(text)}
    )

    classification_result_proto = classifications_pb2.ClassificationResult()
    classification_result_proto.CopyFrom(
        packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME])
    )

    classification_result = _ClassificationResult.create_from_pb2(
        classification_result_proto
    )
    return _extract_language_detector_result(classification_result)
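A minimal usage sketch of the new LanguageDetector API follows, matching the tests above; the model file name and input string are illustrative assumptions, not part of this change:

    from mediapipe.tasks import python
    from mediapipe.tasks.python import text

    # Hypothetical model bundle; any language-detector TFLite model with
    # Model Metadata and language codes in an AssociatedFile should work.
    base_options = python.BaseOptions(model_asset_path='detector.tflite')
    options = text.LanguageDetectorOptions(
        base_options=base_options, score_threshold=0.3
    )
    with text.LanguageDetector.create_from_options(options) as detector:
      result = detector.detect('Il y a beaucoup de bouches qui parlent.')
      for detection in result.detections:
        print(detection.language_code, detection.probability)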
@@ -264,3 +264,22 @@ py_library(
        "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
    ],
)

py_library(
    name = "face_aligner",
    srcs = [
        "face_aligner.py",
    ],
    deps = [
        "//mediapipe/python:_framework_bindings",
        "//mediapipe/python:packet_creator",
        "//mediapipe/python:packet_getter",
        "//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_py_pb2",
        "//mediapipe/tasks/python/core:base_options",
        "//mediapipe/tasks/python/core:optional_dependencies",
        "//mediapipe/tasks/python/core:task_info",
        "//mediapipe/tasks/python/vision/core:base_vision_task_api",
        "//mediapipe/tasks/python/vision/core:image_processing_options",
        "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
    ],
)
@@ -15,6 +15,7 @@
"""MediaPipe Tasks Vision API."""

import mediapipe.tasks.python.vision.core
import mediapipe.tasks.python.vision.face_aligner
import mediapipe.tasks.python.vision.face_detector
import mediapipe.tasks.python.vision.face_landmarker
import mediapipe.tasks.python.vision.face_stylizer

@@ -25,7 +26,10 @@ import mediapipe.tasks.python.vision.image_embedder
import mediapipe.tasks.python.vision.image_segmenter
import mediapipe.tasks.python.vision.interactive_segmenter
import mediapipe.tasks.python.vision.object_detector
import mediapipe.tasks.python.vision.pose_landmarker

FaceAligner = face_aligner.FaceAligner
FaceAlignerOptions = face_aligner.FaceAlignerOptions
FaceDetector = face_detector.FaceDetector
FaceDetectorOptions = face_detector.FaceDetectorOptions
FaceDetectorResult = face_detector.FaceDetectorResult

@@ -41,6 +45,7 @@ GestureRecognizerResult = gesture_recognizer.GestureRecognizerResult
HandLandmarker = hand_landmarker.HandLandmarker
HandLandmarkerOptions = hand_landmarker.HandLandmarkerOptions
HandLandmarkerResult = hand_landmarker.HandLandmarkerResult
HandLandmarksConnections = hand_landmarker.HandLandmarksConnections
ImageClassifier = image_classifier.ImageClassifier
ImageClassifierOptions = image_classifier.ImageClassifierOptions
ImageClassifierResult = image_classifier.ImageClassifierResult

@@ -54,10 +59,16 @@ InteractiveSegmenterOptions = interactive_segmenter.InteractiveSegmenterOptions
InteractiveSegmenterRegionOfInterest = interactive_segmenter.RegionOfInterest
ObjectDetector = object_detector.ObjectDetector
ObjectDetectorOptions = object_detector.ObjectDetectorOptions
ObjectDetectorResult = object_detector.ObjectDetectorResult
PoseLandmarker = pose_landmarker.PoseLandmarker
PoseLandmarkerOptions = pose_landmarker.PoseLandmarkerOptions
PoseLandmarkerResult = pose_landmarker.PoseLandmarkerResult
PoseLandmarksConnections = pose_landmarker.PoseLandmarksConnections
RunningMode = core.vision_task_running_mode.VisionTaskRunningMode

# Remove unnecessary modules to avoid duplication in API docs.
del core
del face_aligner
del face_detector
del face_landmarker
del face_stylizer

@@ -68,4 +79,5 @@ del image_embedder
del image_segmenter
del interactive_segmenter
del object_detector
del pose_landmarker
del mediapipe
mediapipe/tasks/python/vision/face_aligner.py (new file, 158 lines)
@@ -0,0 +1,158 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe face aligner task."""

import dataclasses
from typing import Optional

from mediapipe.python import packet_creator
from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.cc.vision.face_stylizer.proto import face_stylizer_graph_options_pb2
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module

_BaseOptions = base_options_module.BaseOptions
_FaceStylizerGraphOptionsProto = (
    face_stylizer_graph_options_pb2.FaceStylizerGraphOptions
)
_RunningMode = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo

_FACE_ALIGNMENT_IMAGE_NAME = 'face_alignment'
_FACE_ALIGNMENT_IMAGE_TAG = 'FACE_ALIGNMENT'
_NORM_RECT_STREAM_NAME = 'norm_rect_in'
_NORM_RECT_TAG = 'NORM_RECT'
_IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.face_stylizer.FaceStylizerGraph'


@dataclasses.dataclass
class FaceAlignerOptions:
  """Options for the face aligner task.

  Attributes:
    base_options: Base options for the face aligner task.
  """

  base_options: _BaseOptions

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _FaceStylizerGraphOptionsProto:
    """Generates a FaceStylizerGraphOptions protobuf object."""
    base_options_proto = self.base_options.to_pb2()
    base_options_proto.use_stream_mode = False
    return _FaceStylizerGraphOptionsProto(base_options=base_options_proto)


class FaceAligner(base_vision_task_api.BaseVisionTaskApi):
  """Class that performs face alignment on images."""

  @classmethod
  def create_from_model_path(cls, model_path: str) -> 'FaceAligner':
    """Creates a `FaceAligner` object from a face landmarker task bundle and the default `FaceAlignerOptions`.

    Note that the created `FaceAligner` instance is in image mode, for
    aligning one face on a single image input.

    Args:
      model_path: Path to the face landmarker task bundle.

    Returns:
      `FaceAligner` object that's created from the model file and the default
      `FaceAlignerOptions`.

    Raises:
      ValueError: If failed to create `FaceAligner` object from the provided
        file such as invalid file path.
      RuntimeError: If other types of error occurred.
    """
    base_options = _BaseOptions(model_asset_path=model_path)
    options = FaceAlignerOptions(base_options=base_options)
    return cls.create_from_options(options)

  @classmethod
  def create_from_options(cls, options: FaceAlignerOptions) -> 'FaceAligner':
    """Creates the `FaceAligner` object from face aligner options.

    Args:
      options: Options for the face aligner task.

    Returns:
      `FaceAligner` object that's created from `options`.

    Raises:
      ValueError: If failed to create `FaceAligner` object from
        `FaceAlignerOptions` such as missing the model.
      RuntimeError: If other types of error occurred.
    """
    task_info = _TaskInfo(
        task_graph=_TASK_GRAPH_NAME,
        input_streams=[
            ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
            ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
        ],
        output_streams=[
            ':'.join([_FACE_ALIGNMENT_IMAGE_TAG, _FACE_ALIGNMENT_IMAGE_NAME]),
            ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
        ],
        task_options=options,
    )
    return cls(
        task_info.generate_graph_config(enable_flow_limiting=False),
        _RunningMode.IMAGE,
        None,
    )

  def align(
      self,
      image: image_module.Image,
      image_processing_options: Optional[_ImageProcessingOptions] = None,
  ) -> Optional[image_module.Image]:
    """Performs face alignment on the provided MediaPipe Image.

    Only use this method when the FaceAligner is created with the image
    running mode.

    Args:
      image: MediaPipe Image.
      image_processing_options: Options for image processing.

    Returns:
      The aligned face image. The aligned output image size is the same as the
      model output size. None if no face is detected on the input image.

    Raises:
      ValueError: If any of the input arguments is invalid.
      RuntimeError: If face alignment failed to run.
    """
    normalized_rect = self.convert_to_normalized_rect(
        image_processing_options, image
    )
    output_packets = self._process_image_data({
        _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
        _NORM_RECT_STREAM_NAME: packet_creator.create_proto(
            normalized_rect.to_pb2()
        ),
    })
    if output_packets[_FACE_ALIGNMENT_IMAGE_NAME].is_empty():
      return None
    return packet_getter.get_image(output_packets[_FACE_ALIGNMENT_IMAGE_NAME])
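A matching usage sketch for the new FaceAligner API; the file names mirror the test assets above and are assumptions, not part of this change:

    import mediapipe as mp
    from mediapipe.tasks import python
    from mediapipe.tasks.python import vision

    base_options = python.BaseOptions(model_asset_path='face_landmarker_v2.task')
    options = vision.FaceAlignerOptions(base_options=base_options)
    with vision.FaceAligner.create_from_options(options) as aligner:
      image = mp.Image.create_from_file('portrait.jpg')
      aligned = aligner.align(image)
      if aligned is not None:
        # The output size matches the model output, 256x256 for this bundle.
        print(aligned.width, aligned.height)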
@@ -2939,7 +2939,7 @@ class FaceLandmarkerOptions:
  Attributes:
    base_options: Base options for the face landmarker task.
    running_mode: The running mode of the task. Default to the image mode.
      HandLandmarker has three running modes: 1) The image mode for detecting
      FaceLandmarker has three running modes: 1) The image mode for detecting
      face landmarks on single image inputs. 2) The video mode for detecting
      face landmarks on the decoded frames of a video. 3) The live stream mode
      for detecting face landmarks on the live stream of input data, such as
@@ -82,6 +82,65 @@ class HandLandmark(enum.IntEnum):
  PINKY_TIP = 20


class HandLandmarksConnections:
  """The connections between hand landmarks."""

  @dataclasses.dataclass
  class Connection:
    """The connection class for hand landmarks."""

    start: int
    end: int

  HAND_PALM_CONNECTIONS: List[Connection] = [
      Connection(0, 1),
      Connection(1, 5),
      Connection(9, 13),
      Connection(13, 17),
      Connection(5, 9),
      Connection(0, 17),
  ]

  HAND_THUMB_CONNECTIONS: List[Connection] = [
      Connection(1, 2),
      Connection(2, 3),
      Connection(3, 4),
  ]

  HAND_INDEX_FINGER_CONNECTIONS: List[Connection] = [
      Connection(5, 6),
      Connection(6, 7),
      Connection(7, 8),
  ]

  HAND_MIDDLE_FINGER_CONNECTIONS: List[Connection] = [
      Connection(9, 10),
      Connection(10, 11),
      Connection(11, 12),
  ]

  HAND_RING_FINGER_CONNECTIONS: List[Connection] = [
      Connection(13, 14),
      Connection(14, 15),
      Connection(15, 16),
  ]

  HAND_PINKY_FINGER_CONNECTIONS: List[Connection] = [
      Connection(17, 18),
      Connection(18, 19),
      Connection(19, 20),
  ]

  HAND_CONNECTIONS: List[Connection] = (
      HAND_PALM_CONNECTIONS +
      HAND_THUMB_CONNECTIONS +
      HAND_INDEX_FINGER_CONNECTIONS +
      HAND_MIDDLE_FINGER_CONNECTIONS +
      HAND_RING_FINGER_CONNECTIONS +
      HAND_PINKY_FINGER_CONNECTIONS
  )


@dataclasses.dataclass
class HandLandmarkerResult:
  """The hand landmarks result from HandLandmarker, where each vector element represents a single hand detected in the image.
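A sketch of how these connection lists are typically consumed, drawing skeleton edges from a HandLandmarker result. OpenCV (cv2) and the normalized-to-pixel conversion are illustrative assumptions, not part of this change:

    import cv2  # Assumption: OpenCV is available for drawing.

    from mediapipe.tasks.python import vision


    def draw_hand_skeleton(bgr_frame, hand_landmarks):
      """Draws one hand's skeleton in place, given normalized landmarks."""
      h, w = bgr_frame.shape[:2]
      for conn in vision.HandLandmarksConnections.HAND_CONNECTIONS:
        start = hand_landmarks[conn.start]
        end = hand_landmarks[conn.end]
        # Landmarks are normalized to [0, 1]; scale to pixel coordinates.
        cv2.line(
            bgr_frame,
            (int(start.x * w), int(start.y * h)),
            (int(end.x * w), int(end.y * h)),
            (0, 255, 0),
            2,
        )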
@ -88,7 +88,7 @@ class InteractiveSegmenterOptions:
|
|||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def to_pb2(self) -> _ImageSegmenterGraphOptionsProto:
|
||||
"""Generates an InteractiveSegmenterOptions protobuf object."""
|
||||
"""Generates an ImageSegmenterGraphOptions protobuf object."""
|
||||
base_options_proto = self.base_options.to_pb2()
|
||||
base_options_proto.use_stream_mode = False
|
||||
segmenter_options_proto = _SegmenterOptionsProto()
|
||||
|
|
|
@@ -132,6 +132,55 @@ def _build_landmarker_result(
  return pose_landmarker_result


class PoseLandmarksConnections:
  """The connections between pose landmarks."""

  @dataclasses.dataclass
  class Connection:
    """The connection class for pose landmarks."""

    start: int
    end: int

  POSE_LANDMARKS: List[Connection] = [
      Connection(0, 1),
      Connection(1, 2),
      Connection(2, 3),
      Connection(3, 7),
      Connection(0, 4),
      Connection(4, 5),
      Connection(5, 6),
      Connection(6, 8),
      Connection(9, 10),
      Connection(11, 12),
      Connection(11, 13),
      Connection(13, 15),
      Connection(15, 17),
      Connection(15, 19),
      Connection(15, 21),
      Connection(17, 19),
      Connection(12, 14),
      Connection(14, 16),
      Connection(16, 18),
      Connection(16, 20),
      Connection(16, 22),
      Connection(18, 20),
      Connection(11, 23),
      Connection(12, 24),
      Connection(23, 24),
      Connection(23, 25),
      Connection(24, 26),
      Connection(25, 27),
      Connection(26, 28),
      Connection(27, 29),
      Connection(28, 30),
      Connection(29, 31),
      Connection(30, 32),
      Connection(27, 31),
      Connection(28, 32),
  ]


@dataclasses.dataclass
class PoseLandmarkerOptions:
  """Options for the pose landmarker task.
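As with the hand connections above, these are plain (start, end) index pairs into a PoseLandmarker result; converting them to tuples for generic drawing or graph code is a one-liner:

    from mediapipe.tasks.python import vision

    # (start, end) landmark index pairs for generic consumers.
    POSE_EDGES = [
        (c.start, c.end)
        for c in vision.PoseLandmarksConnections.POSE_LANDMARKS
    ]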
@@ -22,6 +22,7 @@
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/vector.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/render_data.pb.h"

namespace mediapipe {
namespace {

@@ -112,6 +113,8 @@ void AnnotationRenderer::RenderDataOnImage(const RenderData& render_data) {
    DrawGradientLine(annotation);
  } else if (annotation.data_case() == RenderAnnotation::kArrow) {
    DrawArrow(annotation);
  } else if (annotation.data_case() == RenderAnnotation::kScribble) {
    DrawScribble(annotation);
  } else {
    LOG(FATAL) << "Unknown annotation type: " << annotation.data_case();
  }

@@ -442,7 +445,11 @@ void AnnotationRenderer::DrawArrow(const RenderAnnotation& annotation) {
}

void AnnotationRenderer::DrawPoint(const RenderAnnotation& annotation) {
  const auto& point = annotation.point();
  DrawPoint(annotation.point(), annotation);
}

void AnnotationRenderer::DrawPoint(const RenderAnnotation::Point& point,
                                   const RenderAnnotation& annotation) {
  int x = -1;
  int y = -1;
  if (point.normalized()) {

@@ -460,6 +467,12 @@ void AnnotationRenderer::DrawPoint(const RenderAnnotation& annotation) {
  cv::circle(mat_image_, point_to_draw, thickness, color, -1);
}

void AnnotationRenderer::DrawScribble(const RenderAnnotation& annotation) {
  for (const RenderAnnotation::Point& point : annotation.scribble().point()) {
    DrawPoint(point, annotation);
  }
}

void AnnotationRenderer::DrawLine(const RenderAnnotation& annotation) {
  int x_start = -1;
  int y_start = -1;

@@ -96,6 +96,11 @@ class AnnotationRenderer {

  // Draws a point on the image as described in the annotation.
  void DrawPoint(const RenderAnnotation& annotation);
  void DrawPoint(const RenderAnnotation::Point& point,
                 const RenderAnnotation& annotation);

  // Draws scribbles on the image as described in the annotation.
  void DrawScribble(const RenderAnnotation& annotation);

  // Draws a line segment on the image as described in the annotation.
  void DrawLine(const RenderAnnotation& annotation);
@@ -131,6 +131,10 @@ message RenderAnnotation {
    optional Color color2 = 7;
  }

  message Scribble {
    repeated Point point = 1;
  }

  message Arrow {
    // The arrow head will be drawn at (x_end, y_end).
    optional double x_start = 1;

@@ -192,6 +196,7 @@ message RenderAnnotation {
    RoundedRectangle rounded_rectangle = 9;
    FilledRoundedRectangle filled_rounded_rectangle = 10;
    GradientLine gradient_line = 14;
    Scribble scribble = 15;
  }

  // Thickness for drawing the annotation.
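A sketch of populating the new Scribble annotation from Python, assuming the generated bindings for mediapipe/util/render_data.proto (render_data_pb2) are available in your build; the coordinates are made up:

    from mediapipe.util import render_data_pb2

    annotation = render_data_pb2.RenderAnnotation(thickness=2.0)
    for x, y in [(0.10, 0.20), (0.12, 0.22), (0.15, 0.21)]:
      point = annotation.scribble.point.add()
      point.x = x
      point.y = y
      point.normalized = True  # Interpreted relative to the image size.
    # AnnotationRenderer::DrawScribble draws each point via DrawPoint.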