Merge branch 'ios-normalized-keypoint-hash' into ios-async-calls-fixes

Prianka Liz Kariat 2023-05-04 17:22:48 +05:30
commit db5fd168b6
44 changed files with 1661 additions and 67 deletions

View File

@ -95,7 +95,8 @@ absl::Status FrameBufferProcessor::Convert(const mediapipe::Image& input,
static_cast<int>(range_max) == 255);
}
auto input_frame = input.GetGpuBuffer().GetReadView<FrameBuffer>();
auto input_frame =
input.GetGpuBuffer(/*upload_to_gpu=*/false).GetReadView<FrameBuffer>();
const auto& output_shape = output_tensor.shape();
MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
FrameBuffer::Dimension output_dimension{/*width=*/output_shape.dims[2],

View File

@ -1285,12 +1285,14 @@ cc_library(
srcs = ["flat_color_image_calculator.cc"],
deps = [
":flat_color_image_calculator_cc_proto",
"//mediapipe/framework:calculator_contract",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:ret_check",
"//mediapipe/util:color_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",

View File

@ -15,14 +15,13 @@
#include <memory>
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/calculators/util/flat_color_image_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_contract.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/util/color.pb.h"
namespace mediapipe {
@ -32,6 +31,7 @@ namespace {
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Node;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::SideOutput;
} // namespace
// A calculator for generating an image filled with a single color.
@ -45,7 +45,8 @@ using ::mediapipe::api2::Output;
//
// Outputs:
// IMAGE (Image)
// Image filled with the requested color.
// Image filled with the requested color. Can be either an output_stream
// or an output_side_packet.
//
// Example usage:
// node {
@ -68,9 +69,10 @@ class FlatColorImageCalculator : public Node {
public:
static constexpr Input<Image>::Optional kInImage{"IMAGE"};
static constexpr Input<Color>::Optional kInColor{"COLOR"};
static constexpr Output<Image> kOutImage{"IMAGE"};
static constexpr Output<Image>::Optional kOutImage{"IMAGE"};
static constexpr SideOutput<Image>::Optional kOutSideImage{"IMAGE"};
MEDIAPIPE_NODE_CONTRACT(kInImage, kInColor, kOutImage);
MEDIAPIPE_NODE_CONTRACT(kInImage, kInColor, kOutImage, kOutSideImage);
static absl::Status UpdateContract(CalculatorContract* cc) {
const auto& options = cc->Options<FlatColorImageCalculatorOptions>();
@ -81,6 +83,13 @@ class FlatColorImageCalculator : public Node {
RET_CHECK(kInColor(cc).IsConnected() ^ options.has_color())
<< "Either set COLOR input stream, or set through options";
RET_CHECK(kOutImage(cc).IsConnected() ^ kOutSideImage(cc).IsConnected())
<< "Set IMAGE either as output stream, or as output side packet";
RET_CHECK(!kOutSideImage(cc).IsConnected() ||
(options.has_output_height() && options.has_output_width()))
<< "Set size through options, when setting IMAGE as output side packet";
return absl::OkStatus();
}
@ -88,6 +97,9 @@ class FlatColorImageCalculator : public Node {
absl::Status Process(CalculatorContext* cc) override;
private:
std::optional<std::shared_ptr<ImageFrame>> CreateOutputFrame(
CalculatorContext* cc);
bool use_dimension_from_option_ = false;
bool use_color_from_option_ = false;
};
@ -96,10 +108,31 @@ MEDIAPIPE_REGISTER_NODE(FlatColorImageCalculator);
absl::Status FlatColorImageCalculator::Open(CalculatorContext* cc) {
use_dimension_from_option_ = !kInImage(cc).IsConnected();
use_color_from_option_ = !kInColor(cc).IsConnected();
if (!kOutImage(cc).IsConnected()) {
std::optional<std::shared_ptr<ImageFrame>> output_frame =
CreateOutputFrame(cc);
if (output_frame.has_value()) {
kOutSideImage(cc).Set(Image(output_frame.value()));
}
}
return absl::OkStatus();
}
absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) {
if (kOutImage(cc).IsConnected()) {
std::optional<std::shared_ptr<ImageFrame>> output_frame =
CreateOutputFrame(cc);
if (output_frame.has_value()) {
kOutImage(cc).Send(Image(output_frame.value()));
}
}
return absl::OkStatus();
}
std::optional<std::shared_ptr<ImageFrame>>
FlatColorImageCalculator::CreateOutputFrame(CalculatorContext* cc) {
const auto& options = cc->Options<FlatColorImageCalculatorOptions>();
int output_height = -1;
@ -112,7 +145,7 @@ absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) {
output_height = input_image.height();
output_width = input_image.width();
} else {
return absl::OkStatus();
return std::nullopt;
}
Color color;
@ -121,7 +154,7 @@ absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) {
} else if (!kInColor(cc).IsEmpty()) {
color = kInColor(cc).Get();
} else {
return absl::OkStatus();
return std::nullopt;
}
auto output_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
@ -130,9 +163,7 @@ absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) {
output_mat.setTo(cv::Scalar(color.r(), color.g(), color.b()));
kOutImage(cc).Send(Image(output_frame));
return absl::OkStatus();
return output_frame;
}
} // namespace mediapipe

View File

@ -113,6 +113,35 @@ TEST(FlatColorImageCalculatorTest, SpecifyDimensionThroughOptions) {
}
}
TEST(FlatColorImageCalculatorTest, ProducesOutputSidePacket) {
CalculatorRunner runner(R"pb(
calculator: "FlatColorImageCalculator"
output_side_packet: "IMAGE:out_packet"
options {
[mediapipe.FlatColorImageCalculatorOptions.ext] {
output_width: 1
output_height: 1
color: {
r: 100,
g: 200,
b: 255,
}
}
}
)pb");
MP_ASSERT_OK(runner.Run());
const auto& image = runner.OutputSidePackets().Tag(kImageTag).Get<Image>();
EXPECT_EQ(image.width(), 1);
EXPECT_EQ(image.height(), 1);
auto image_frame = image.GetImageFrameSharedPtr();
const uint8_t* pixel_data = image_frame->PixelData();
EXPECT_EQ(pixel_data[0], 100);
EXPECT_EQ(pixel_data[1], 200);
EXPECT_EQ(pixel_data[2], 255);
}
TEST(FlatColorImageCalculatorTest, FailureMissingDimension) {
CalculatorRunner runner(R"pb(
calculator: "FlatColorImageCalculator"
@ -206,5 +235,56 @@ TEST(FlatColorImageCalculatorTest, FailureDuplicateColor) {
HasSubstr("Either set COLOR input stream"));
}
TEST(FlatColorImageCalculatorTest, FailureDuplicateOutputs) {
CalculatorRunner runner(R"pb(
calculator: "FlatColorImageCalculator"
output_stream: "IMAGE:out_image"
output_side_packet: "IMAGE:out_packet"
options {
[mediapipe.FlatColorImageCalculatorOptions.ext] {
output_width: 1
output_height: 1
color: {
r: 100,
g: 200,
b: 255,
}
}
}
)pb");
ASSERT_THAT(
runner.Run().message(),
HasSubstr("Set IMAGE either as output stream, or as output side packet"));
}
TEST(FlatColorImageCalculatorTest, FailureSettingInputImageOnOutputSidePacket) {
CalculatorRunner runner(R"pb(
calculator: "FlatColorImageCalculator"
input_stream: "IMAGE:image"
output_side_packet: "IMAGE:out_packet"
options {
[mediapipe.FlatColorImageCalculatorOptions.ext] {
color: {
r: 100,
g: 200,
b: 255,
}
}
}
)pb");
auto image_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
kImageWidth, kImageHeight);
for (int ts = 0; ts < 3; ++ts) {
runner.MutableInputs()->Tag(kImageTag).packets.push_back(
MakePacket<Image>(image_frame).At(Timestamp(ts)));
}
ASSERT_THAT(runner.Run().message(),
HasSubstr("Set size through options, when setting IMAGE as "
"output side packet"));
}
} // namespace
} // namespace mediapipe

View File

@ -190,14 +190,16 @@ TEST(PaddingEffectGeneratorTest, ScaleToMultipleOfTwo) {
double target_aspect_ratio = 0.5;
int expect_width = 14;
int expect_height = input_height;
auto test_frame = absl::make_unique<ImageFrame>(/*format=*/ImageFormat::SRGB,
input_width, input_height);
ImageFrame test_frame(/*format=*/ImageFormat::SRGB, input_width,
input_height);
cv::Mat mat = formats::MatView(&test_frame);
mat = cv::Scalar(0, 0, 0);
PaddingEffectGenerator generator(test_frame->Width(), test_frame->Height(),
PaddingEffectGenerator generator(test_frame.Width(), test_frame.Height(),
target_aspect_ratio,
/*scale_to_multiple_of_two=*/true);
ImageFrame result_frame;
MP_ASSERT_OK(generator.Process(*test_frame, 0.3, 40, 0.0, &result_frame));
MP_ASSERT_OK(generator.Process(test_frame, 0.3, 40, 0.0, &result_frame));
EXPECT_EQ(result_frame.Width(), expect_width);
EXPECT_EQ(result_frame.Height(), expect_height);
}

View File

@ -113,11 +113,11 @@ class Image {
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
#endif // !MEDIAPIPE_DISABLE_GPU
// Get a GPU view. Automatically uploads from CPU if needed.
const mediapipe::GpuBuffer GetGpuBuffer() const {
#if !MEDIAPIPE_DISABLE_GPU
if (use_gpu_ == false) ConvertToGpu();
#endif // !MEDIAPIPE_DISABLE_GPU
// Provides access to the underlying GpuBuffer storage.
// Automatically uploads from CPU to GPU if needed and requested through the
// `upload_to_gpu` argument.
const mediapipe::GpuBuffer GetGpuBuffer(bool upload_to_gpu = true) const {
if (!use_gpu_ && upload_to_gpu) ConvertToGpu();
return gpu_buffer_;
}
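A rough usage sketch of the new flag (hypothetical caller code; it assumes an `image` of type mediapipe::Image with CPU-backed pixels and the relevant headers already included):

// Access the storage without forcing a CPU-to-GPU upload, e.g. to take a
// CPU-backed FrameBuffer read view (same pattern as the FrameBufferProcessor
// hunk at the top of this commit).
auto frame_buffer_view =
    image.GetGpuBuffer(/*upload_to_gpu=*/false).GetReadView<FrameBuffer>();
// The default call keeps the previous behavior: upload from CPU if needed.
mediapipe::GpuBuffer gpu_buffer = image.GetGpuBuffer();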

View File

@ -0,0 +1,31 @@
"""Rules implementation for mediapipe_proto_alias.bzl, do not load directly."""
def _copy_header_impl(ctx):
source = ctx.attr.source.replace("//", "").replace(":", "/")
files = []
for dep in ctx.attr.deps:
for header in dep[CcInfo].compilation_context.direct_headers:
if (header.short_path == source):
files.append(header)
if len(files) != 1:
fail("Expected exactly 1 source, got ", str(files))
dest_file = ctx.actions.declare_file(ctx.attr.filename)
# Use expand_template() with no substitutions as a simple copier.
ctx.actions.expand_template(
template = files[0],
output = dest_file,
substitutions = {},
)
return [DefaultInfo(files = depset([dest_file]))]
copy_header = rule(
implementation = _copy_header_impl,
attrs = {
"filename": attr.string(),
"source": attr.string(),
"deps": attr.label_list(providers = [CcInfo]),
},
output_to_genfiles = True,
outputs = {"out": "%{filename}"},
)

View File

@ -791,6 +791,7 @@ cc_library(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@stblib//:stb_image",
"@stblib//:stb_image_write",
],

View File

@ -26,6 +26,7 @@
#include "absl/status/status.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/calculator.pb.h"
@ -311,6 +312,13 @@ std::unique_ptr<ImageFrame> LoadTestPng(absl::string_view path,
// Returns the path to the output if successful.
absl::StatusOr<std::string> SavePngTestOutput(
const mediapipe::ImageFrame& image, absl::string_view prefix) {
absl::flat_hash_set<ImageFormat::Format> supported_formats = {
ImageFormat::GRAY8, ImageFormat::SRGB, ImageFormat::SRGBA,
ImageFormat::LAB8, ImageFormat::SBGRA};
if (!supported_formats.contains(image.Format())) {
return absl::CancelledError(
absl::StrFormat("Format %d can not be saved to PNG.", image.Format()));
}
std::string now_string = absl::FormatTime(absl::Now());
std::string output_relative_path =
absl::StrCat(prefix, "_", now_string, ".png");

View File

@ -59,7 +59,9 @@ class ObjectDetectorModel(tf.keras.Model):
self._num_classes = num_classes
self._model = self._build_model()
checkpoint_folder = self._model_spec.downloaded_files.get_path()
checkpoint_file = os.path.join(checkpoint_folder, 'ckpt-277200')
checkpoint_file = os.path.join(
checkpoint_folder, self._model_spec.checkpoint_name
)
self.load_checkpoint(checkpoint_file)
self._model.summary()
self.loss_trackers = [
@ -80,7 +82,10 @@ class ObjectDetectorModel(tf.keras.Model):
num_scales=3, aspect_ratios=[0.5, 1.0, 2.0], anchor_size=3
),
backbone=configs.backbones.Backbone(
type='mobilenet', mobilenet=configs.backbones.MobileNet()
type='mobilenet',
mobilenet=configs.backbones.MobileNet(
model_id=self._model_spec.model_id
),
),
decoder=configs.decoders.Decoder(
type='fpn',

View File

@ -26,6 +26,12 @@ MOBILENET_V2_FILES = file_util.DownloadedFiles(
is_folder=True,
)
MOBILENET_MULTI_AVG_FILES = file_util.DownloadedFiles(
'object_detector/mobilenetmultiavg',
'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv3.5_ssd_coco/mobilenetv3.5_ssd_i256_ckpt.tar.gz',
is_folder=True,
)
@dataclasses.dataclass
class ModelSpec(object):
@ -38,13 +44,25 @@ class ModelSpec(object):
stddev_rgb = (127.5,)
downloaded_files: file_util.DownloadedFiles
checkpoint_name: str
input_image_shape: List[int]
model_id: str
mobilenet_v2_spec = functools.partial(
ModelSpec,
downloaded_files=MOBILENET_V2_FILES,
checkpoint_name='ckpt-277200',
input_image_shape=[256, 256, 3],
model_id='MobileNetV2',
)
mobilenet_multi_avg_spec = functools.partial(
ModelSpec,
downloaded_files=MOBILENET_MULTI_AVG_FILES,
checkpoint_name='ckpt-277200',
input_image_shape=[256, 256, 3],
model_id='MobileNetMultiAVG',
)
@ -53,6 +71,7 @@ class SupportedModels(enum.Enum):
"""Predefined object detector model specs supported by Model Maker."""
MOBILENET_V2 = mobilenet_v2_spec
MOBILENET_MULTI_AVG = mobilenet_multi_avg_spec
@classmethod
def get(cls, spec: 'SupportedModels') -> 'ModelSpec':

View File

@ -93,3 +93,8 @@ mediapipe_proto_library(
"//mediapipe/framework:calculator_proto",
],
)
mediapipe_proto_library(
name = "transformer_params_proto",
srcs = ["transformer_params.proto"],
)

View File

@ -0,0 +1,46 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto3";
package mediapipe.tasks.components.processors.proto;
option java_package = "com.google.mediapipe.tasks.components.processors.proto";
option java_outer_classname = "TransformerParametersProto";
// The parameters of transformer (https://arxiv.org/pdf/1706.03762.pdf)
message TransformerParameters {
// Batch size of tensors.
int32 batch_size = 1;
// Maximum sequence length of the input/output tensor.
int32 max_seq_length = 2;
// Embedding dimension (or model dimension), `d_model` in the paper.
// `d_k` == `d_v` == `d_model`/`h`.
int32 embedding_dim = 3;
// Hidden dimension used in the feedforward layer, `d_ff` in the paper.
int32 hidden_dimension = 4;
// Head dimension, `d_k` or `d_v` in the paper.
int32 head_dimension = 5;
// Number of heads, `h` in the paper.
int32 num_heads = 6;
// Number of stacked transformers, `N` in the paper.
int32 num_stacks = 7;
}
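As an illustration only, the message can be filled from C++ via the generated proto setters; the values below follow the base configuration of the cited paper (d_model = 512, h = 8, hence d_k = d_v = 64, d_ff = 2048, N = 6) rather than anything in this change, and the generated header path is an assumption:

// Assumed generated header:
// "mediapipe/tasks/cc/components/processors/proto/transformer_params.pb.h"
mediapipe::tasks::components::processors::proto::TransformerParameters params;
params.set_batch_size(1);           // illustrative
params.set_max_seq_length(512);     // illustrative
params.set_embedding_dim(512);      // d_model
params.set_num_heads(8);            // h
params.set_head_dimension(64);      // d_k == d_v == d_model / h
params.set_hidden_dimension(2048);  // d_ff
params.set_num_stacks(6);           // N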

View File

@ -242,7 +242,7 @@ class FaceDetectorGraph : public core::ModelTaskGraph {
auto matrix = preprocessing.Out(kMatrixTag);
auto image_size = preprocessing.Out(kImageSizeTag);
// Face detection model inferece.
// Face detection model inference.
auto& inference = AddInference(
model_resources, subgraph_options.base_options().acceleration(), graph);
preprocessed_tensors >> inference.In(kTensorsTag);

View File

@ -199,7 +199,9 @@ void ConfigureTensorsToImageCalculator(
// STYLIZED_IMAGE - mediapipe::Image
// The face stylization output image.
// FACE_ALIGNMENT - mediapipe::Image
// The face alignment output image.
// The aligned face image that is fed to the face stylization model to
// perform stylization. Also useful for preparing face stylization training
// data.
// IMAGE - mediapipe::Image
// The input image that the face landmarker runs on and has the pixel data
// stored on the target storage (CPU vs GPU).
@ -211,6 +213,7 @@ void ConfigureTensorsToImageCalculator(
// input_stream: "NORM_RECT:norm_rect"
// output_stream: "IMAGE:image_out"
// output_stream: "STYLIZED_IMAGE:stylized_image"
// output_stream: "FACE_ALIGNMENT:face_alignment_image"
// options {
// [mediapipe.tasks.vision.face_stylizer.proto.FaceStylizerGraphOptions.ext]
// {
@ -248,7 +251,7 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
->mutable_face_landmarker_graph_options(),
graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
const ModelResources* face_stylizer_model_resources;
const ModelResources* face_stylizer_model_resources = nullptr;
if (output_stylized) {
ASSIGN_OR_RETURN(
const auto* model_resources,
@ -332,7 +335,7 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
auto face_rect = face_to_rect.Out(kNormRectTag);
std::optional<Source<Image>> face_alignment;
// Output face alignment only.
// Output aligned face only.
// In this case, the face stylization model inference is not required.
// However, to keep consistent with the inference preprocessing steps, the
// ImageToTensorCalculator is still used to perform image rotation,

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/keypoint.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
@ -60,6 +61,8 @@ constexpr absl::string_view kNormRectTag{"NORM_RECT"};
constexpr absl::string_view kSubgraphTypeName{
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"};
using components::containers::NormalizedKeypoint;
using ::mediapipe::CalculatorGraphConfig;
using ::mediapipe::Image;
using ::mediapipe::NormalizedRect;
@ -115,7 +118,7 @@ absl::StatusOr<RenderData> ConvertRoiToRenderData(const RegionOfInterest& roi) {
case RegionOfInterest::Format::kUnspecified:
return absl::InvalidArgumentError(
"RegionOfInterest format not specified");
case RegionOfInterest::Format::kKeyPoint:
case RegionOfInterest::Format::kKeyPoint: {
RET_CHECK(roi.keypoint.has_value());
auto* annotation = result.add_render_annotations();
annotation->mutable_color()->set_r(255);
@ -124,6 +127,19 @@ absl::StatusOr<RenderData> ConvertRoiToRenderData(const RegionOfInterest& roi) {
point->set_x(roi.keypoint->x);
point->set_y(roi.keypoint->y);
return result;
}
case RegionOfInterest::Format::kScribble: {
RET_CHECK(roi.scribble.has_value());
auto* annotation = result.add_render_annotations();
annotation->mutable_color()->set_r(255);
for (const NormalizedKeypoint& keypoint : *(roi.scribble)) {
auto* point = annotation->mutable_scribble()->add_point();
point->set_normalized(true);
point->set_x(keypoint.x);
point->set_y(keypoint.y);
}
return result;
}
}
return absl::UnimplementedError("Unrecognized format");
}

View File

@ -53,6 +53,7 @@ struct RegionOfInterest {
enum class Format {
kUnspecified = 0, // Format not specified.
kKeyPoint = 1, // Using keypoint to represent ROI.
kScribble = 2, // Using scribble to represent ROI.
};
// Specifies the format used to specify the region-of-interest. Note that
@ -61,8 +62,13 @@ struct RegionOfInterest {
Format format = Format::kUnspecified;
// Represents the ROI in keypoint format, this should be non-nullopt if
// `format` is `KEYPOINT`.
// `format` is `kKeyPoint`.
std::optional<components::containers::NormalizedKeypoint> keypoint;
// Represents the ROI in scribble format, this should be non-nullopt if
// `format` is `kScribble`.
std::optional<std::vector<components::containers::NormalizedKeypoint>>
scribble;
};
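A minimal sketch of building a scribble ROI with the new field (mirrors the test changes further below; NormalizedKeypoint is components::containers::NormalizedKeypoint):

RegionOfInterest roi;
roi.format = RegionOfInterest::Format::kScribble;
roi.scribble = std::vector{NormalizedKeypoint{0.44, 0.70},
                           NormalizedKeypoint{0.44, 0.71},
                           NormalizedKeypoint{0.44, 0.72}};
// A kKeyPoint ROI instead sets roi.keypoint to a single NormalizedKeypoint.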
// Performs interactive segmentation on images.

View File

@ -18,9 +18,12 @@ limitations under the License.
#include <memory>
#include <string>
#include <utility>
#include <variant>
#include <vector>
#include "absl/flags/flag.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h"
@ -179,22 +182,46 @@ TEST_F(CreateFromOptionsTest, FailsWithNeitherOutputSet) {
struct InteractiveSegmenterTestParams {
std::string test_name;
RegionOfInterest::Format format;
NormalizedKeypoint roi;
std::variant<NormalizedKeypoint, std::vector<NormalizedKeypoint>> roi;
absl::string_view golden_mask_file;
float similarity_threshold;
};
using SucceedSegmentationWithRoi =
::testing::TestWithParam<InteractiveSegmenterTestParams>;
class SucceedSegmentationWithRoi
: public ::testing::TestWithParam<InteractiveSegmenterTestParams> {
public:
absl::StatusOr<RegionOfInterest> TestParamsToTaskOptions() {
const InteractiveSegmenterTestParams& params = GetParam();
RegionOfInterest interaction_roi;
interaction_roi.format = params.format;
switch (params.format) {
case (RegionOfInterest::Format::kKeyPoint): {
interaction_roi.keypoint = std::get<NormalizedKeypoint>(params.roi);
break;
}
case (RegionOfInterest::Format::kScribble): {
interaction_roi.scribble =
std::get<std::vector<NormalizedKeypoint>>(params.roi);
break;
}
default: {
return absl::InvalidArgumentError("Unknown ROI format");
}
}
return interaction_roi;
}
};
TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
MP_ASSERT_OK_AND_ASSIGN(RegionOfInterest interaction_roi,
TestParamsToTaskOptions());
const InteractiveSegmenterTestParams& params = GetParam();
MP_ASSERT_OK_AND_ASSIGN(
Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
RegionOfInterest interaction_roi;
interaction_roi.format = params.format;
interaction_roi.keypoint = params.roi;
auto options = std::make_unique<InteractiveSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kPtmModel);
@ -220,13 +247,13 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
}
TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
const auto& params = GetParam();
MP_ASSERT_OK_AND_ASSIGN(RegionOfInterest interaction_roi,
TestParamsToTaskOptions());
const InteractiveSegmenterTestParams& params = GetParam();
MP_ASSERT_OK_AND_ASSIGN(
Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
RegionOfInterest interaction_roi;
interaction_roi.format = params.format;
interaction_roi.keypoint = params.roi;
auto options = std::make_unique<InteractiveSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kPtmModel);
@ -253,11 +280,23 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
INSTANTIATE_TEST_SUITE_P(
SucceedSegmentationWithRoiTest, SucceedSegmentationWithRoi,
::testing::ValuesIn<InteractiveSegmenterTestParams>(
{{"PointToDog1", RegionOfInterest::Format::kKeyPoint,
{// Keypoint input.
{"PointToDog1", RegionOfInterest::Format::kKeyPoint,
NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsMaskDog1, 0.84f},
{"PointToDog2", RegionOfInterest::Format::kKeyPoint,
NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsMaskDog2,
kGoldenMaskSimilarity}}),
kGoldenMaskSimilarity},
// Scribble input.
{"ScribbleToDog1", RegionOfInterest::Format::kScribble,
std::vector{NormalizedKeypoint{0.44, 0.70},
NormalizedKeypoint{0.44, 0.71},
NormalizedKeypoint{0.44, 0.72}},
kCatsAndDogsMaskDog1, 0.84f},
{"ScribbleToDog2", RegionOfInterest::Format::kScribble,
std::vector{NormalizedKeypoint{0.66, 0.66},
NormalizedKeypoint{0.66, 0.67},
NormalizedKeypoint{0.66, 0.68}},
kCatsAndDogsMaskDog2, kGoldenMaskSimilarity}}),
[](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>&
info) { return info.param.test_name; });

View File

@ -108,9 +108,18 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
->mutable_model_asset(),
is_copy);
}
pose_detector_graph_options->mutable_base_options()
->mutable_acceleration()
->CopyFrom(options->base_options().acceleration());
if (options->base_options().acceleration().has_gpu()) {
core::proto::Acceleration gpu_accel;
gpu_accel.mutable_gpu()->set_use_advanced_gpu_api(true);
pose_detector_graph_options->mutable_base_options()
->mutable_acceleration()
->CopyFrom(gpu_accel);
} else {
pose_detector_graph_options->mutable_base_options()
->mutable_acceleration()
->CopyFrom(options->base_options().acceleration());
}
pose_detector_graph_options->mutable_base_options()->set_use_stream_mode(
options->base_options().use_stream_mode());
auto* pose_landmarks_detector_graph_options =

View File

@ -28,7 +28,12 @@
return self;
}
// TODO: Implement hash
- (NSUInteger)hash {
NSUInteger nonNullPropertiesHash =
@(self.location.x).hash ^ @(self.location.y).hash ^ @(self.score).hash;
return self.label ? nonNullPropertiesHash ^ self.label.hash : nonNullPropertiesHash;
}
- (BOOL)isEqual:(nullable id)object {
if (!object) {

View File

@ -180,6 +180,7 @@ android_library(
srcs = [
"poselandmarker/PoseLandmarker.java",
"poselandmarker/PoseLandmarkerResult.java",
"poselandmarker/PoseLandmarksConnections.java",
],
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
@ -212,6 +213,7 @@ android_library(
"handlandmarker/HandLandmark.java",
"handlandmarker/HandLandmarker.java",
"handlandmarker/HandLandmarkerResult.java",
"handlandmarker/HandLandmarksConnections.java",
],
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",

View File

@ -77,11 +77,13 @@ public class BaseVisionTaskApi implements AutoCloseable {
}
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
inputPackets.put(
normRectStreamName,
runner
.getPacketCreator()
.createProto(convertToNormalizedRect(imageProcessingOptions, image)));
if (!normRectStreamName.isEmpty()) {
inputPackets.put(
normRectStreamName,
runner
.getPacketCreator()
.createProto(convertToNormalizedRect(imageProcessingOptions, image)));
}
return runner.process(inputPackets);
}
@ -105,11 +107,13 @@ public class BaseVisionTaskApi implements AutoCloseable {
}
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
inputPackets.put(
normRectStreamName,
runner
.getPacketCreator()
.createProto(convertToNormalizedRect(imageProcessingOptions, image)));
if (!normRectStreamName.isEmpty()) {
inputPackets.put(
normRectStreamName,
runner
.getPacketCreator()
.createProto(convertToNormalizedRect(imageProcessingOptions, image)));
}
return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
}
@ -133,11 +137,13 @@ public class BaseVisionTaskApi implements AutoCloseable {
}
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
inputPackets.put(
normRectStreamName,
runner
.getPacketCreator()
.createProto(convertToNormalizedRect(imageProcessingOptions, image)));
if (!normRectStreamName.isEmpty()) {
inputPackets.put(
normRectStreamName,
runner
.getPacketCreator()
.createProto(convertToNormalizedRect(imageProcessingOptions, image)));
}
runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
}

View File

@ -0,0 +1,105 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.handlandmarker;
import com.google.auto.value.AutoValue;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/** Hand landmarks connection constants. */
public final class HandLandmarksConnections {
/** Value class representing hand landmarks connection. */
@AutoValue
public abstract static class Connection {
static Connection create(int start, int end) {
return new AutoValue_HandLandmarksConnections_Connection(start, end);
}
public abstract int start();
public abstract int end();
}
@SuppressWarnings("ConstantCaseForConstants")
public static final Set<Connection> HAND_PALM_CONNECTIONS =
Collections.unmodifiableSet(
new HashSet<>(
Arrays.asList(
Connection.create(0, 1),
Connection.create(0, 5),
Connection.create(9, 13),
Connection.create(13, 17),
Connection.create(5, 9),
Connection.create(0, 17))));
@SuppressWarnings("ConstantCaseForConstants")
public static final Set<Connection> HAND_THUMB_CONNECTIONS =
Collections.unmodifiableSet(
new HashSet<>(
Arrays.asList(
Connection.create(1, 2), Connection.create(2, 3), Connection.create(3, 4))));
@SuppressWarnings("ConstantCaseForConstants")
public static final Set<Connection> HAND_INDEX_FINGER_CONNECTIONS =
Collections.unmodifiableSet(
new HashSet<>(
Arrays.asList(
Connection.create(5, 6), Connection.create(6, 7), Connection.create(7, 8))));
@SuppressWarnings("ConstantCaseForConstants")
public static final Set<Connection> HAND_MIDDLE_FINGER_CONNECTIONS =
Collections.unmodifiableSet(
new HashSet<>(
Arrays.asList(
Connection.create(9, 10), Connection.create(10, 11), Connection.create(11, 12))));
@SuppressWarnings("ConstantCaseForConstants")
public static final Set<Connection> HAND_RING_FINGER_CONNECTIONS =
Collections.unmodifiableSet(
new HashSet<>(
Arrays.asList(
Connection.create(13, 14),
Connection.create(14, 15),
Connection.create(15, 16))));
@SuppressWarnings("ConstantCaseForConstants")
public static final Set<Connection> HAND_PINKY_FINGER_CONNECTIONS =
Collections.unmodifiableSet(
new HashSet<>(
Arrays.asList(
Connection.create(17, 18),
Connection.create(18, 19),
Connection.create(19, 20))));
@SuppressWarnings("ConstantCaseForConstants")
public static final Set<Connection> HAND_CONNECTIONS =
Collections.unmodifiableSet(
Stream.of(
HAND_PALM_CONNECTIONS.stream(),
HAND_THUMB_CONNECTIONS.stream(),
HAND_INDEX_FINGER_CONNECTIONS.stream(),
HAND_MIDDLE_FINGER_CONNECTIONS.stream(),
HAND_RING_FINGER_CONNECTIONS.stream(),
HAND_PINKY_FINGER_CONNECTIONS.stream())
.flatMap(i -> i)
.collect(Collectors.toSet()));
private HandLandmarksConnections() {}
}

View File

@ -502,6 +502,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
/** The Region-Of-Interest (ROI) to interact with. */
public static class RegionOfInterest {
private NormalizedKeypoint keypoint;
private List<NormalizedKeypoint> scribble;
private RegionOfInterest() {}
@ -514,6 +515,16 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
roi.keypoint = keypoint;
return roi;
}
/**
* Creates a {@link RegionOfInterest} instance representing scribbles over the object that the
* user wants to segment.
*/
public static RegionOfInterest create(List<NormalizedKeypoint> scribble) {
RegionOfInterest roi = new RegionOfInterest();
roi.scribble = scribble;
return roi;
}
}
/**
@ -535,6 +546,18 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
.setX(roi.keypoint.x())
.setY(roi.keypoint.y())))
.build();
} else if (roi.scribble != null) {
RenderAnnotation.Scribble.Builder scribbleBuilder = RenderAnnotation.Scribble.newBuilder();
for (NormalizedKeypoint p : roi.scribble) {
scribbleBuilder.addPoint(RenderAnnotation.Point.newBuilder().setX(p.x()).setY(p.y()));
}
return builder
.addRenderAnnotations(
RenderAnnotation.newBuilder()
.setColor(Color.newBuilder().setR(255))
.setScribble(scribbleBuilder))
.build();
}
throw new IllegalArgumentException(

View File

@ -0,0 +1,80 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.poselandmarker;
import com.google.auto.value.AutoValue;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
/** Pose landmarks connection constants. */
public final class PoseLandmarksConnections {
/** Value class representing pose landmarks connection. */
@AutoValue
public abstract static class Connection {
static Connection create(int start, int end) {
return new AutoValue_PoseLandmarksConnections_Connection(start, end);
}
public abstract int start();
public abstract int end();
}
@SuppressWarnings("ConstantCaseForConstants")
public static final Set<Connection> POSE_LANDMARKS =
Collections.unmodifiableSet(
new HashSet<>(
Arrays.asList(
Connection.create(0, 1),
Connection.create(1, 2),
Connection.create(2, 3),
Connection.create(3, 7),
Connection.create(0, 4),
Connection.create(4, 5),
Connection.create(5, 6),
Connection.create(6, 8),
Connection.create(9, 10),
Connection.create(11, 12),
Connection.create(11, 13),
Connection.create(13, 15),
Connection.create(15, 17),
Connection.create(15, 19),
Connection.create(15, 21),
Connection.create(17, 19),
Connection.create(12, 14),
Connection.create(14, 16),
Connection.create(16, 18),
Connection.create(16, 20),
Connection.create(16, 22),
Connection.create(18, 20),
Connection.create(11, 23),
Connection.create(12, 24),
Connection.create(23, 24),
Connection.create(23, 25),
Connection.create(24, 26),
Connection.create(25, 27),
Connection.create(26, 28),
Connection.create(27, 29),
Connection.create(28, 30),
Connection.create(29, 31),
Connection.create(30, 32),
Connection.create(27, 31),
Connection.create(28, 32))));
private PoseLandmarksConnections() {}
}

View File

@ -27,6 +27,7 @@ import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenterResult;
import com.google.mediapipe.tasks.vision.interactivesegmenter.InteractiveSegmenter.InteractiveSegmenterOptions;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;
import org.junit.Test;
import org.junit.runner.RunWith;
@ -36,7 +37,8 @@ import org.junit.runners.Suite.SuiteClasses;
/** Test for {@link InteractiveSegmenter}. */
@RunWith(Suite.class)
@SuiteClasses({
InteractiveSegmenterTest.General.class,
InteractiveSegmenterTest.KeypointRoi.class,
InteractiveSegmenterTest.ScribbleRoi.class,
})
public class InteractiveSegmenterTest {
private static final String DEEPLAB_MODEL_FILE = "ptm_512_hdt_ptm_woid.tflite";
@ -44,7 +46,7 @@ public class InteractiveSegmenterTest {
private static final int MAGNIFICATION_FACTOR = 10;
@RunWith(AndroidJUnit4.class)
public static final class General extends InteractiveSegmenterTest {
public static final class KeypointRoi extends InteractiveSegmenterTest {
@Test
public void segment_successWithCategoryMask() throws Exception {
final String inputImageName = CATS_AND_DOGS_IMAGE;
@ -86,6 +88,57 @@ public class InteractiveSegmenterTest {
}
}
@RunWith(AndroidJUnit4.class)
public static final class ScribbleRoi extends InteractiveSegmenterTest {
@Test
public void segment_successWithCategoryMask() throws Exception {
final String inputImageName = CATS_AND_DOGS_IMAGE;
ArrayList<NormalizedKeypoint> scribble = new ArrayList<>();
scribble.add(NormalizedKeypoint.create(0.25f, 0.9f));
scribble.add(NormalizedKeypoint.create(0.25f, 0.91f));
scribble.add(NormalizedKeypoint.create(0.25f, 0.92f));
final InteractiveSegmenter.RegionOfInterest roi =
InteractiveSegmenter.RegionOfInterest.create(scribble);
InteractiveSegmenterOptions options =
InteractiveSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputConfidenceMasks(false)
.setOutputCategoryMask(true)
.build();
InteractiveSegmenter imageSegmenter =
InteractiveSegmenter.createFromOptions(
ApplicationProvider.getApplicationContext(), options);
MPImage image = getImageFromAsset(inputImageName);
ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi);
assertThat(actualResult.categoryMask().isPresent()).isTrue();
}
@Test
public void segment_successWithConfidenceMask() throws Exception {
final String inputImageName = CATS_AND_DOGS_IMAGE;
ArrayList<NormalizedKeypoint> scribble = new ArrayList<>();
scribble.add(NormalizedKeypoint.create(0.25f, 0.9f));
scribble.add(NormalizedKeypoint.create(0.25f, 0.91f));
scribble.add(NormalizedKeypoint.create(0.25f, 0.92f));
final InteractiveSegmenter.RegionOfInterest roi =
InteractiveSegmenter.RegionOfInterest.create(scribble);
InteractiveSegmenterOptions options =
InteractiveSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputConfidenceMasks(true)
.setOutputCategoryMask(false)
.build();
InteractiveSegmenter imageSegmenter =
InteractiveSegmenter.createFromOptions(
ApplicationProvider.getApplicationContext(), options);
ImageSegmenterResult actualResult =
imageSegmenter.segment(getImageFromAsset(inputImageName), roi);
assertThat(actualResult.confidenceMasks().isPresent()).isTrue();
List<MPImage> confidenceMasks = actualResult.confidenceMasks().get();
assertThat(confidenceMasks.size()).isEqualTo(2);
}
}
private static MPImage getImageFromAsset(String filePath) throws Exception {
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
InputStream istr = assetManager.open(filePath);

View File

@ -39,9 +39,11 @@ _Landmark = landmark_module.Landmark
class LandmarksDetectionResult:
"""Represents the landmarks detection result.
Attributes: landmarks : A list of `NormalizedLandmark` objects. categories : A
list of `Category` objects. world_landmarks : A list of `Landmark` objects.
rect : A `NormalizedRect` object.
Attributes:
landmarks: A list of `NormalizedLandmark` objects.
categories: A list of `Category` objects.
world_landmarks: A list of `Landmark` objects.
rect: A `NormalizedRect` object.
"""
landmarks: Optional[List[_NormalizedLandmark]]

View File

@ -49,3 +49,18 @@ py_test(
"//mediapipe/tasks/python/text:text_embedder",
],
)
py_test(
name = "language_detector_test",
srcs = ["language_detector_test.py"],
data = [
"//mediapipe/tasks/testdata/text:language_detector",
],
deps = [
"//mediapipe/tasks/python/components/containers:category",
"//mediapipe/tasks/python/components/containers:classification_result",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/text:language_detector",
],
)

View File

@ -0,0 +1,228 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for language detector."""
import enum
import os
from absl.testing import absltest
from absl.testing import parameterized
from mediapipe.tasks.python.components.containers import category
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.text import language_detector
LanguageDetectorResult = language_detector.LanguageDetectorResult
LanguageDetectorPrediction = (
language_detector.LanguageDetectorResult.Detection
)
_BaseOptions = base_options_module.BaseOptions
_Category = category.Category
_Classifications = classification_result_module.Classifications
_LanguageDetector = language_detector.LanguageDetector
_LanguageDetectorOptions = language_detector.LanguageDetectorOptions
_LANGUAGE_DETECTOR_MODEL = "language_detector.tflite"
_TEST_DATA_DIR = "mediapipe/tasks/testdata/text"
_SCORE_THRESHOLD = 0.3
_EN_TEXT = "To be, or not to be, that is the question"
_EN_EXPECTED_RESULT = LanguageDetectorResult(
[LanguageDetectorPrediction("en", 0.999856)]
)
_FR_TEXT = (
"Il y a beaucoup de bouches qui parlent et fort peu de têtes qui pensent."
)
_FR_EXPECTED_RESULT = LanguageDetectorResult(
[LanguageDetectorPrediction("fr", 0.999781)]
)
_RU_TEXT = "это какой-то английский язык"
_RU_EXPECTED_RESULT = LanguageDetectorResult(
[LanguageDetectorPrediction("ru", 0.993362)]
)
_MIXED_TEXT = "分久必合合久必分"
_MIXED_EXPECTED_RESULT = LanguageDetectorResult([
LanguageDetectorPrediction("zh", 0.505424),
LanguageDetectorPrediction("ja", 0.481617),
])
_TOLERANCE = 1e-6
class ModelFileType(enum.Enum):
FILE_CONTENT = 1
FILE_NAME = 2
class LanguageDetectorTest(parameterized.TestCase):
def setUp(self):
super().setUp()
self.model_path = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _LANGUAGE_DETECTOR_MODEL)
)
def _expect_language_detector_result_correct(
self,
actual_result: LanguageDetectorResult,
expect_result: LanguageDetectorResult,
):
for i, prediction in enumerate(actual_result.detections):
expected_prediction = expect_result.detections[i]
self.assertEqual(
prediction.language_code,
expected_prediction.language_code,
)
self.assertAlmostEqual(
prediction.probability,
expected_prediction.probability,
delta=_TOLERANCE,
)
def test_create_from_file_succeeds_with_valid_model_path(self):
# Creates with default option and valid model file successfully.
with _LanguageDetector.create_from_model_path(self.model_path) as detector:
self.assertIsInstance(detector, _LanguageDetector)
def test_create_from_options_succeeds_with_valid_model_path(self):
# Creates with options containing model file successfully.
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _LanguageDetectorOptions(base_options=base_options)
with _LanguageDetector.create_from_options(options) as detector:
self.assertIsInstance(detector, _LanguageDetector)
def test_create_from_options_fails_with_invalid_model_path(self):
with self.assertRaisesRegex(
RuntimeError, "Unable to open file at /path/to/invalid/model.tflite"
):
base_options = _BaseOptions(
model_asset_path="/path/to/invalid/model.tflite"
)
options = _LanguageDetectorOptions(base_options=base_options)
_LanguageDetector.create_from_options(options)
def test_create_from_options_succeeds_with_valid_model_content(self):
# Creates with options containing model content successfully.
with open(self.model_path, "rb") as f:
base_options = _BaseOptions(model_asset_buffer=f.read())
options = _LanguageDetectorOptions(base_options=base_options)
detector = _LanguageDetector.create_from_options(options)
self.assertIsInstance(detector, _LanguageDetector)
@parameterized.parameters(
(ModelFileType.FILE_NAME, _EN_TEXT, _EN_EXPECTED_RESULT),
(ModelFileType.FILE_CONTENT, _EN_TEXT, _EN_EXPECTED_RESULT),
(ModelFileType.FILE_NAME, _FR_TEXT, _FR_EXPECTED_RESULT),
(ModelFileType.FILE_CONTENT, _FR_TEXT, _FR_EXPECTED_RESULT),
(ModelFileType.FILE_NAME, _RU_TEXT, _RU_EXPECTED_RESULT),
(ModelFileType.FILE_CONTENT, _RU_TEXT, _RU_EXPECTED_RESULT),
(ModelFileType.FILE_NAME, _MIXED_TEXT, _MIXED_EXPECTED_RESULT),
(ModelFileType.FILE_CONTENT, _MIXED_TEXT, _MIXED_EXPECTED_RESULT),
)
def test_detect(self, model_file_type, text, expected_result):
# Creates detector.
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
with open(self.model_path, "rb") as f:
model_content = f.read()
base_options = _BaseOptions(model_asset_buffer=model_content)
else:
# Should never happen
raise ValueError("model_file_type is invalid.")
options = _LanguageDetectorOptions(
base_options=base_options, score_threshold=_SCORE_THRESHOLD
)
detector = _LanguageDetector.create_from_options(options)
# Performs language detection on the input.
text_result = detector.detect(text)
# Comparing results.
self._expect_language_detector_result_correct(text_result, expected_result)
# Closes the detector explicitly when the detector is not used in
# a context.
detector.close()
@parameterized.parameters(
(ModelFileType.FILE_NAME, _EN_TEXT, _EN_EXPECTED_RESULT),
(ModelFileType.FILE_NAME, _FR_TEXT, _FR_EXPECTED_RESULT),
(ModelFileType.FILE_NAME, _RU_TEXT, _RU_EXPECTED_RESULT),
(ModelFileType.FILE_CONTENT, _MIXED_TEXT, _MIXED_EXPECTED_RESULT),
)
def test_detect_in_context(self, model_file_type, text, expected_result):
# Creates detector.
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
with open(self.model_path, "rb") as f:
model_content = f.read()
base_options = _BaseOptions(model_asset_buffer=model_content)
else:
# Should never happen
raise ValueError("model_file_type is invalid.")
options = _LanguageDetectorOptions(
base_options=base_options, score_threshold=_SCORE_THRESHOLD
)
with _LanguageDetector.create_from_options(options) as detector:
# Performs language detection on the input.
text_result = detector.detect(text)
# Comparing results.
self._expect_language_detector_result_correct(
text_result, expected_result
)
def test_allowlist_option(self):
# Creates detector.
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _LanguageDetectorOptions(
base_options=base_options,
score_threshold=_SCORE_THRESHOLD,
category_allowlist=["ja"],
)
with _LanguageDetector.create_from_options(options) as detector:
# Performs language detection on the input.
text_result = detector.detect(_MIXED_TEXT)
# Comparing results.
expected_result = LanguageDetectorResult(
[LanguageDetectorPrediction("ja", 0.481617)]
)
self._expect_language_detector_result_correct(
text_result, expected_result
)
def test_denylist_option(self):
# Creates detector.
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _LanguageDetectorOptions(
base_options=base_options,
score_threshold=_SCORE_THRESHOLD,
category_denylist=["ja"],
)
with _LanguageDetector.create_from_options(options) as detector:
# Performs language detection on the input.
text_result = detector.detect(_MIXED_TEXT)
# Comparing results.
expected_result = LanguageDetectorResult(
[LanguageDetectorPrediction("zh", 0.505424)]
)
self._expect_language_detector_result_correct(
text_result, expected_result
)
if __name__ == "__main__":
absltest.main()

View File

@ -185,3 +185,20 @@ py_test(
"@com_google_protobuf//:protobuf_python",
],
)
py_test(
name = "face_aligner_test",
srcs = ["face_aligner_test.py"],
data = [
"//mediapipe/tasks/testdata/vision:test_images",
"//mediapipe/tasks/testdata/vision:test_models",
],
deps = [
"//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/python/components/containers:rect",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/vision:face_aligner",
"//mediapipe/tasks/python/vision/core:image_processing_options",
],
)

View File

@ -0,0 +1,190 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for face aligner."""
import enum
import os
from absl.testing import absltest
from absl.testing import parameterized
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import face_aligner
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
_BaseOptions = base_options_module.BaseOptions
_Rect = rect.Rect
_Image = image_module.Image
_FaceAligner = face_aligner.FaceAligner
_FaceAlignerOptions = face_aligner.FaceAlignerOptions
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_MODEL = 'face_landmarker_v2.task'
_LARGE_FACE_IMAGE = 'portrait.jpg'
_MODEL_IMAGE_SIZE = 256
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
class ModelFileType(enum.Enum):
FILE_CONTENT = 1
FILE_NAME = 2
class FaceAlignerTest(parameterized.TestCase):
def setUp(self):
super().setUp()
self.test_image = _Image.create_from_file(
test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _LARGE_FACE_IMAGE)
)
)
self.model_path = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _MODEL)
)
def test_create_from_file_succeeds_with_valid_model_path(self):
# Creates with default option and valid model file successfully.
with _FaceAligner.create_from_model_path(self.model_path) as aligner:
self.assertIsInstance(aligner, _FaceAligner)
def test_create_from_options_succeeds_with_valid_model_path(self):
# Creates with options containing model file successfully.
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _FaceAlignerOptions(base_options=base_options)
with _FaceAligner.create_from_options(options) as aligner:
self.assertIsInstance(aligner, _FaceAligner)
def test_create_from_options_fails_with_invalid_model_path(self):
with self.assertRaisesRegex(
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'
):
base_options = _BaseOptions(
model_asset_path='/path/to/invalid/model.tflite'
)
options = _FaceAlignerOptions(base_options=base_options)
_FaceAligner.create_from_options(options)
def test_create_from_options_succeeds_with_valid_model_content(self):
# Creates with options containing model content successfully.
with open(self.model_path, 'rb') as f:
base_options = _BaseOptions(model_asset_buffer=f.read())
options = _FaceAlignerOptions(base_options=base_options)
aligner = _FaceAligner.create_from_options(options)
self.assertIsInstance(aligner, _FaceAligner)
@parameterized.parameters(
(ModelFileType.FILE_NAME, _LARGE_FACE_IMAGE),
(ModelFileType.FILE_CONTENT, _LARGE_FACE_IMAGE),
)
def test_align(self, model_file_type, image_file_name):
# Load the test image.
self.test_image = _Image.create_from_file(
test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, image_file_name)
)
)
# Creates aligner.
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
with open(self.model_path, 'rb') as f:
model_content = f.read()
base_options = _BaseOptions(model_asset_buffer=model_content)
else:
# Should never happen
raise ValueError('model_file_type is invalid.')
options = _FaceAlignerOptions(base_options=base_options)
aligner = _FaceAligner.create_from_options(options)
# Performs face alignment on the input.
alignd_image = aligner.align(self.test_image)
self.assertIsInstance(alignd_image, _Image)
# Closes the aligner explicitly when the aligner is not used in
# a context.
aligner.close()
@parameterized.parameters(
(ModelFileType.FILE_NAME, _LARGE_FACE_IMAGE),
(ModelFileType.FILE_CONTENT, _LARGE_FACE_IMAGE),
)
def test_align_in_context(self, model_file_type, image_file_name):
# Load the test image.
self.test_image = _Image.create_from_file(
test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, image_file_name)
)
)
# Creates aligner.
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
with open(self.model_path, 'rb') as f:
model_content = f.read()
base_options = _BaseOptions(model_asset_buffer=model_content)
else:
# Should never happen
raise ValueError('model_file_type is invalid.')
options = _FaceAlignerOptions(base_options=base_options)
with _FaceAligner.create_from_options(options) as aligner:
# Performs face alignment on the input.
alignd_image = aligner.align(self.test_image)
self.assertIsInstance(alignd_image, _Image)
self.assertEqual(alignd_image.width, _MODEL_IMAGE_SIZE)
self.assertEqual(alignd_image.height, _MODEL_IMAGE_SIZE)
def test_align_succeeds_with_region_of_interest(self):
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _FaceAlignerOptions(base_options=base_options)
with _FaceAligner.create_from_options(options) as aligner:
# Load the test image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _LARGE_FACE_IMAGE)
)
)
# Region-of-interest around the face.
roi = _Rect(left=0.32, top=0.02, right=0.67, bottom=0.32)
image_processing_options = _ImageProcessingOptions(roi)
# Performs face alignment on the input.
alignd_image = aligner.align(test_image, image_processing_options)
self.assertIsInstance(alignd_image, _Image)
self.assertEqual(alignd_image.width, _MODEL_IMAGE_SIZE)
self.assertEqual(alignd_image.height, _MODEL_IMAGE_SIZE)
def test_align_succeeds_with_no_face_detected(self):
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _FaceAlignerOptions(base_options=base_options)
with _FaceAligner.create_from_options(options) as aligner:
# Load the test image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _LARGE_FACE_IMAGE)
)
)
# Region-of-interest that doesn't contain a human face.
roi = _Rect(left=0.1, top=0.1, right=0.2, bottom=0.2)
image_processing_options = _ImageProcessingOptions(roi)
# Performs face alignment on the input.
alignd_image = aligner.align(test_image, image_processing_options)
self.assertIsNone(alignd_image)
if __name__ == '__main__':
absltest.main()

View File

@ -57,3 +57,22 @@ py_library(
"//mediapipe/tasks/python/text/core:base_text_task_api",
],
)
py_library(
name = "language_detector",
srcs = [
"language_detector.py",
],
deps = [
"//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2",
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_py_pb2",
"//mediapipe/tasks/python/components/containers:classification_result",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info",
"//mediapipe/tasks/python/text/core:base_text_task_api",
],
)

View File

@ -14,9 +14,13 @@
"""MediaPipe Tasks Text API."""
import mediapipe.tasks.python.text.language_detector
import mediapipe.tasks.python.text.text_classifier
import mediapipe.tasks.python.text.text_embedder
LanguageDetector = language_detector.LanguageDetector
LanguageDetectorOptions = language_detector.LanguageDetectorOptions
LanguageDetectorResult = language_detector.LanguageDetectorResult
TextClassifier = text_classifier.TextClassifier
TextClassifierOptions = text_classifier.TextClassifierOptions
TextClassifierResult = text_classifier.TextClassifierResult
@ -26,5 +30,6 @@ TextEmbedderResult = text_embedder.TextEmbedderResult
# Remove unnecessary modules to avoid duplication in API docs.
del mediapipe
del language_detector
del text_classifier
del text_embedder

View File

@ -0,0 +1,220 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe language detector task."""
import dataclasses
from typing import List, Optional
from mediapipe.python import packet_creator
from mediapipe.python import packet_getter
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2
from mediapipe.tasks.cc.text.text_classifier.proto import text_classifier_graph_options_pb2
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.text.core import base_text_task_api
_ClassificationResult = classification_result_module.ClassificationResult
_BaseOptions = base_options_module.BaseOptions
_TextClassifierGraphOptionsProto = (
text_classifier_graph_options_pb2.TextClassifierGraphOptions
)
_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions
_TaskInfo = task_info_module.TaskInfo
_CLASSIFICATIONS_STREAM_NAME = 'classifications_out'
_CLASSIFICATIONS_TAG = 'CLASSIFICATIONS'
_TEXT_IN_STREAM_NAME = 'text_in'
_TEXT_TAG = 'TEXT'
_TASK_GRAPH_NAME = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph'
@dataclasses.dataclass
class LanguageDetectorResult:
@dataclasses.dataclass
class Detection:
"""A language code and its probability."""
# An i18n language / locale code, e.g. "en" for English, "uz" for Uzbek,
# "ja"-Latn for Japanese (romaji).
language_code: str
probability: float
detections: List[Detection]
def _extract_language_detector_result(
classification_result: classification_result_module.ClassificationResult,
) -> LanguageDetectorResult:
"""Extracts a LanguageDetectorResult from a ClassificationResult."""
if len(classification_result.classifications) != 1:
raise ValueError(
'The LanguageDetector TextClassifierGraph should have exactly one '
'classification head.'
)
languages_and_scores = classification_result.classifications[0]
language_detector_result = LanguageDetectorResult([])
for category in languages_and_scores.categories:
if category.category_name is None:
raise ValueError(
'LanguageDetector ClassificationResult has a missing language code.'
)
prediction = LanguageDetectorResult.Detection(
category.category_name, category.score
)
language_detector_result.detections.append(prediction)
return language_detector_result
@dataclasses.dataclass
class LanguageDetectorOptions:
"""Options for the language detector task.
Attributes:
base_options: Base options for the language detector task.
display_names_locale: The locale to use for display names specified through
the TFLite Model Metadata.
max_results: The maximum number of top-scored classification results to
return.
score_threshold: Overrides the ones provided in the model metadata. Results
below this value are rejected.
category_allowlist: Allowlist of category names. If non-empty,
classification results whose category name is not in this set will be
filtered out. Duplicate or unknown category names are ignored. Mutually
exclusive with `category_denylist`.
category_denylist: Denylist of category names. If non-empty, classification
results whose category name is in this set will be filtered out. Duplicate
or unknown category names are ignored. Mutually exclusive with
`category_allowlist`.
"""
base_options: _BaseOptions
display_names_locale: Optional[str] = None
max_results: Optional[int] = None
score_threshold: Optional[float] = None
category_allowlist: Optional[List[str]] = None
category_denylist: Optional[List[str]] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _TextClassifierGraphOptionsProto:
"""Generates an TextClassifierOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
classifier_options_proto = _ClassifierOptionsProto(
score_threshold=self.score_threshold,
category_allowlist=self.category_allowlist,
category_denylist=self.category_denylist,
display_names_locale=self.display_names_locale,
max_results=self.max_results,
)
return _TextClassifierGraphOptionsProto(
base_options=base_options_proto,
classifier_options=classifier_options_proto,
)
class LanguageDetector(base_text_task_api.BaseTextTaskApi):
"""Class that predicts the language of an input text.
This API expects a TFLite model with TFLite Model Metadata that contains the
mandatory (described below) input tensors, output tensor, and the language
codes in an AssociatedFile.
Input tensors:
(kTfLiteString)
- 1 input tensor that is scalar or has shape [1] containing the input
string.
Output tensor:
(kTfLiteFloat32)
- 1 output tensor of shape `[1 x N]` where `N` is the number of languages.
"""
@classmethod
def create_from_model_path(cls, model_path: str) -> 'LanguageDetector':
"""Creates an `LanguageDetector` object from a TensorFlow Lite model and the default `LanguageDetectorOptions`.
Args:
model_path: Path to the model.
Returns:
`LanguageDetector` object that's created from the model file and the
default `LanguageDetectorOptions`.
Raises:
ValueError: If failed to create `LanguageDetector` object from the
provided file such as invalid file path.
RuntimeError: If other types of error occurred.
"""
base_options = _BaseOptions(model_asset_path=model_path)
options = LanguageDetectorOptions(base_options=base_options)
return cls.create_from_options(options)
@classmethod
def create_from_options(
cls, options: LanguageDetectorOptions
) -> 'LanguageDetector':
"""Creates the `LanguageDetector` object from language detector options.
Args:
options: Options for the language detector task.
Returns:
`LanguageDetector` object that's created from `options`.
Raises:
ValueError: If failed to create `LanguageDetector` object from
`LanguageDetectorOptions` such as missing the model.
RuntimeError: If other types of error occurred.
"""
task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME,
input_streams=[':'.join([_TEXT_TAG, _TEXT_IN_STREAM_NAME])],
output_streams=[
':'.join([_CLASSIFICATIONS_TAG, _CLASSIFICATIONS_STREAM_NAME])
],
task_options=options,
)
return cls(task_info.generate_graph_config())
def detect(self, text: str) -> LanguageDetectorResult:
"""Predicts the language of the input `text`.
Args:
text: The input text.
Returns:
A `LanguageDetectorResult` object that contains a list of languages and
scores.
Raises:
ValueError: If any of the input arguments is invalid.
RuntimeError: If language detection failed to run.
"""
output_packets = self._runner.process(
{_TEXT_IN_STREAM_NAME: packet_creator.create_string(text)}
)
classification_result_proto = classifications_pb2.ClassificationResult()
classification_result_proto.CopyFrom(
packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME])
)
classification_result = _ClassificationResult.create_from_pb2(
classification_result_proto
)
return _extract_language_detector_result(classification_result)
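
A hedged usage sketch of the API added above (the model path, input text, and `max_results` value are placeholders, not part of this change):

from mediapipe.tasks import python
from mediapipe.tasks.python import text

# Placeholder model path; any LanguageDetector-compatible TFLite model works.
options = text.LanguageDetectorOptions(
    base_options=python.BaseOptions(model_asset_path='detector.tflite'),
    max_results=3,
)
with text.LanguageDetector.create_from_options(options) as detector:
  result = detector.detect('El gato doméstico es un mamífero.')
  for detection in result.detections:
    print(f'{detection.language_code}: {detection.probability:.2f}')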

View File

@@ -264,3 +264,22 @@ py_library(
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
],
)
py_library(
name = "face_aligner",
srcs = [
"face_aligner.py",
],
deps = [
"//mediapipe/python:_framework_bindings",
"//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_py_pb2",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info",
"//mediapipe/tasks/python/vision/core:base_vision_task_api",
"//mediapipe/tasks/python/vision/core:image_processing_options",
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
],
)

View File

@@ -15,6 +15,7 @@
"""MediaPipe Tasks Vision API."""
import mediapipe.tasks.python.vision.core
import mediapipe.tasks.python.vision.face_aligner
import mediapipe.tasks.python.vision.face_detector
import mediapipe.tasks.python.vision.face_landmarker
import mediapipe.tasks.python.vision.face_stylizer
@@ -25,7 +26,10 @@ import mediapipe.tasks.python.vision.image_embedder
import mediapipe.tasks.python.vision.image_segmenter
import mediapipe.tasks.python.vision.interactive_segmenter
import mediapipe.tasks.python.vision.object_detector
import mediapipe.tasks.python.vision.pose_landmarker
FaceAligner = face_aligner.FaceAligner
FaceAlignerOptions = face_aligner.FaceAlignerOptions
FaceDetector = face_detector.FaceDetector
FaceDetectorOptions = face_detector.FaceDetectorOptions
FaceDetectorResult = face_detector.FaceDetectorResult
@@ -41,6 +45,7 @@ GestureRecognizerResult = gesture_recognizer.GestureRecognizerResult
HandLandmarker = hand_landmarker.HandLandmarker
HandLandmarkerOptions = hand_landmarker.HandLandmarkerOptions
HandLandmarkerResult = hand_landmarker.HandLandmarkerResult
HandLandmarksConnections = hand_landmarker.HandLandmarksConnections
ImageClassifier = image_classifier.ImageClassifier
ImageClassifierOptions = image_classifier.ImageClassifierOptions
ImageClassifierResult = image_classifier.ImageClassifierResult
@@ -54,10 +59,16 @@ InteractiveSegmenterOptions = interactive_segmenter.InteractiveSegmenterOptions
InteractiveSegmenterRegionOfInterest = interactive_segmenter.RegionOfInterest
ObjectDetector = object_detector.ObjectDetector
ObjectDetectorOptions = object_detector.ObjectDetectorOptions
ObjectDetectorResult = object_detector.ObjectDetectorResult
PoseLandmarker = pose_landmarker.PoseLandmarker
PoseLandmarkerOptions = pose_landmarker.PoseLandmarkerOptions
PoseLandmarkerResult = pose_landmarker.PoseLandmarkerResult
PoseLandmarksConnections = pose_landmarker.PoseLandmarksConnections
RunningMode = core.vision_task_running_mode.VisionTaskRunningMode
# Remove unnecessary modules to avoid duplication in API docs.
del core
del face_aligner
del face_detector
del face_landmarker
del face_stylizer
@@ -68,4 +79,5 @@ del image_embedder
del image_segmenter
del interactive_segmenter
del object_detector
del pose_landmarker
del mediapipe

View File

@@ -0,0 +1,158 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe face aligner task."""
import dataclasses
from typing import Optional
from mediapipe.python import packet_creator
from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.cc.vision.face_stylizer.proto import face_stylizer_graph_options_pb2
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
_BaseOptions = base_options_module.BaseOptions
_FaceStylizerGraphOptionsProto = (
face_stylizer_graph_options_pb2.FaceStylizerGraphOptions
)
_RunningMode = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo
_FACE_ALIGNMENT_IMAGE_NAME = 'face_alignment'
_FACE_ALIGNMENT_IMAGE_TAG = 'FACE_ALIGNMENT'
_NORM_RECT_STREAM_NAME = 'norm_rect_in'
_NORM_RECT_TAG = 'NORM_RECT'
_IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.face_stylizer.FaceStylizerGraph'
@dataclasses.dataclass
class FaceAlignerOptions:
"""Options for the face aligner task.
Attributes:
base_options: Base options for the face aligner task.
"""
base_options: _BaseOptions
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _FaceStylizerGraphOptionsProto:
"""Generates a FaceStylizerOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False
return _FaceStylizerGraphOptionsProto(base_options=base_options_proto)
class FaceAligner(base_vision_task_api.BaseVisionTaskApi):
"""Class that performs face alignment on images."""
@classmethod
def create_from_model_path(cls, model_path: str) -> 'FaceAligner':
"""Creates a `FaceAligner` object from a face landmarker task bundle and the default `FaceAlignerOptions`.
Note that the created `FaceAligner` instance is in image mode, for
aligning one face on a single image input.
Args:
model_path: Path to the face landmarker task bundle.
Returns:
`FaceAligner` object that's created from the model file and the default
`FaceAlignerOptions`.
Raises:
ValueError: If failed to create `FaceAligner` object from the provided
file such as invalid file path.
RuntimeError: If other types of error occurred.
"""
base_options = _BaseOptions(model_asset_path=model_path)
options = FaceAlignerOptions(base_options=base_options)
return cls.create_from_options(options)
@classmethod
def create_from_options(cls, options: FaceAlignerOptions) -> 'FaceAligner':
"""Creates the `FaceAligner` object from face aligner options.
Args:
options: Options for the face aligner task.
Returns:
`FaceAligner` object that's created from `options`.
Raises:
ValueError: If failed to create `FaceAligner` object from
`FaceAlignerOptions` such as missing the model.
RuntimeError: If other types of error occurred.
"""
task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME,
input_streams=[
':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
],
output_streams=[
':'.join([_FACE_ALIGNMENT_IMAGE_TAG, _FACE_ALIGNMENT_IMAGE_NAME]),
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
],
task_options=options,
)
return cls(
task_info.generate_graph_config(enable_flow_limiting=False),
_RunningMode.IMAGE,
None,
)
def align(
self,
image: image_module.Image,
image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> Optional[image_module.Image]:
"""Performs face alignment on the provided MediaPipe Image.
Only use this method when the FaceAligner is created with the image
running mode.
Args:
image: MediaPipe Image.
image_processing_options: Options for image processing.
Returns:
The aligned face image. The aligned output image size is the same as the
model output size. None if no face is detected on the input image.
Raises:
ValueError: If any of the input arguments is invalid.
RuntimeError: If face alignment failed to run.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, image
)
output_packets = self._process_image_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
_NORM_RECT_STREAM_NAME: packet_creator.create_proto(
normalized_rect.to_pb2()
),
})
if output_packets[_FACE_ALIGNMENT_IMAGE_NAME].is_empty():
return None
return packet_getter.get_image(output_packets[_FACE_ALIGNMENT_IMAGE_NAME])
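
A hedged usage sketch of the aligner defined above (the task bundle and image paths are placeholders, not part of this change):

import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision

# Placeholder paths; substitute a real face landmarker task bundle and image.
options = vision.FaceAlignerOptions(
    base_options=python.BaseOptions(model_asset_path='face_landmarker.task')
)
with vision.FaceAligner.create_from_options(options) as aligner:
  image = mp.Image.create_from_file('portrait.jpg')
  aligned = aligner.align(image)
  if aligned is not None:
    print(aligned.width, aligned.height)  # Equals the model output size.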

View File

@@ -2939,7 +2939,7 @@ class FaceLandmarkerOptions:
Attributes:
base_options: Base options for the face landmarker task.
running_mode: The running mode of the task. Default to the image mode.
HandLandmarker has three running modes: 1) The image mode for detecting
FaceLandmarker has three running modes: 1) The image mode for detecting
face landmarks on single image inputs. 2) The video mode for detecting
face landmarks on the decoded frames of a video. 3) The live stream mode
for detecting face landmarks on the live stream of input data, such as

View File

@@ -82,6 +82,65 @@ class HandLandmark(enum.IntEnum):
PINKY_TIP = 20
class HandLandmarksConnections:
"""The connections between hand landmarks."""
@dataclasses.dataclass
class Connection:
"""The connection class for hand landmarks."""
start: int
end: int
HAND_PALM_CONNECTIONS: List[Connection] = [
Connection(0, 1),
Connection(1, 5),
Connection(9, 13),
Connection(13, 17),
Connection(5, 9),
Connection(0, 17),
]
HAND_THUMB_CONNECTIONS: List[Connection] = [
Connection(1, 2),
Connection(2, 3),
Connection(3, 4),
]
HAND_INDEX_FINGER_CONNECTIONS: List[Connection] = [
Connection(5, 6),
Connection(6, 7),
Connection(7, 8),
]
HAND_MIDDLE_FINGER_CONNECTIONS: List[Connection] = [
Connection(9, 10),
Connection(10, 11),
Connection(11, 12),
]
HAND_RING_FINGER_CONNECTIONS: List[Connection] = [
Connection(13, 14),
Connection(14, 15),
Connection(15, 16),
]
HAND_PINKY_FINGER_CONNECTIONS: List[Connection] = [
Connection(17, 18),
Connection(18, 19),
Connection(19, 20),
]
HAND_CONNECTIONS: List[Connection] = (
HAND_PALM_CONNECTIONS +
HAND_THUMB_CONNECTIONS +
HAND_INDEX_FINGER_CONNECTIONS +
HAND_MIDDLE_FINGER_CONNECTIONS +
HAND_RING_FINGER_CONNECTIONS +
HAND_PINKY_FINGER_CONNECTIONS
)
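
As an illustration only (not part of this change), the connection lists above can drive skeleton drawing; this sketch assumes a HandLandmarkerResult `result` and an OpenCV BGR frame `frame` are available:

import cv2

def draw_hand_connections(frame, result):
  """Draws HAND_CONNECTIONS for each detected hand on a BGR frame (sketch)."""
  h, w = frame.shape[:2]
  for hand in result.hand_landmarks:  # One landmark list per detected hand.
    for conn in HandLandmarksConnections.HAND_CONNECTIONS:
      start, end = hand[conn.start], hand[conn.end]
      # Landmarks are normalized to [0, 1]; scale to pixel coordinates.
      cv2.line(frame,
               (int(start.x * w), int(start.y * h)),
               (int(end.x * w), int(end.y * h)),
               (0, 255, 0), 2)
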
@dataclasses.dataclass
class HandLandmarkerResult:
"""The hand landmarks result from HandLandmarker, where each vector element represents a single hand detected in the image.

View File

@@ -88,7 +88,7 @@ class InteractiveSegmenterOptions:
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ImageSegmenterGraphOptionsProto:
"""Generates an InteractiveSegmenterOptions protobuf object."""
"""Generates an ImageSegmenterGraphOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False
segmenter_options_proto = _SegmenterOptionsProto()

View File

@@ -132,6 +132,55 @@ def _build_landmarker_result(
return pose_landmarker_result
class PoseLandmarksConnections:
"""The connections between pose landmarks."""
@dataclasses.dataclass
class Connection:
"""The connection class for pose landmarks."""
start: int
end: int
POSE_LANDMARKS: List[Connection] = [
Connection(0, 1),
Connection(1, 2),
Connection(2, 3),
Connection(3, 7),
Connection(0, 4),
Connection(4, 5),
Connection(5, 6),
Connection(6, 8),
Connection(9, 10),
Connection(11, 12),
Connection(11, 13),
Connection(13, 15),
Connection(15, 17),
Connection(15, 19),
Connection(15, 21),
Connection(17, 19),
Connection(12, 14),
Connection(14, 16),
Connection(16, 18),
Connection(16, 20),
Connection(16, 22),
Connection(18, 20),
Connection(11, 23),
Connection(12, 24),
Connection(23, 24),
Connection(23, 25),
Connection(24, 26),
Connection(25, 27),
Connection(26, 28),
Connection(27, 29),
Connection(28, 30),
Connection(29, 31),
Connection(30, 32),
Connection(27, 31),
Connection(28, 32)
]
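
For illustration only (not part of this change), the pose connection list can be consumed the same way; a sketch computing per-connection segment lengths from a PoseLandmarkerResult's world landmarks (in meters):

import math

def pose_segment_lengths(result):
  """Returns the 3D length of each pose connection per detected person (sketch)."""
  lengths = []
  for pose in result.pose_world_landmarks:  # One landmark list per person.
    lengths.append([
        math.dist((pose[c.start].x, pose[c.start].y, pose[c.start].z),
                  (pose[c.end].x, pose[c.end].y, pose[c.end].z))
        for c in PoseLandmarksConnections.POSE_LANDMARKS
    ])
  return lengths
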
@dataclasses.dataclass
class PoseLandmarkerOptions:
"""Options for the pose landmarker task.

View File

@@ -22,6 +22,7 @@
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/vector.h"
#include "mediapipe/util/color.pb.h"
#include "mediapipe/util/render_data.pb.h"
namespace mediapipe {
namespace {
@@ -112,6 +113,8 @@ void AnnotationRenderer::RenderDataOnImage(const RenderData& render_data) {
DrawGradientLine(annotation);
} else if (annotation.data_case() == RenderAnnotation::kArrow) {
DrawArrow(annotation);
} else if (annotation.data_case() == RenderAnnotation::kScribble) {
DrawScribble(annotation);
} else {
LOG(FATAL) << "Unknown annotation type: " << annotation.data_case();
}
@@ -442,7 +445,11 @@ void AnnotationRenderer::DrawArrow(const RenderAnnotation& annotation) {
}
void AnnotationRenderer::DrawPoint(const RenderAnnotation& annotation) {
const auto& point = annotation.point();
DrawPoint(annotation.point(), annotation);
}
void AnnotationRenderer::DrawPoint(const RenderAnnotation::Point& point,
const RenderAnnotation& annotation) {
int x = -1;
int y = -1;
if (point.normalized()) {
@@ -460,6 +467,12 @@ void AnnotationRenderer::DrawPoint(const RenderAnnotation& annotation) {
cv::circle(mat_image_, point_to_draw, thickness, color, -1);
}
void AnnotationRenderer::DrawScribble(const RenderAnnotation& annotation) {
for (const RenderAnnotation::Point& point : annotation.scribble().point()) {
DrawPoint(point, annotation);
}
}
void AnnotationRenderer::DrawLine(const RenderAnnotation& annotation) {
int x_start = -1;
int y_start = -1;

View File

@@ -96,6 +96,11 @@ class AnnotationRenderer {
// Draws a point on the image as described in the annotation.
void DrawPoint(const RenderAnnotation& annotation);
void DrawPoint(const RenderAnnotation::Point& point,
const RenderAnnotation& annotation);
// Draws scribbles on the image as described in the annotation.
void DrawScribble(const RenderAnnotation& annotation);
// Draws a line segment on the image as described in the annotation.
void DrawLine(const RenderAnnotation& annotation);

View File

@@ -131,6 +131,10 @@ message RenderAnnotation {
optional Color color2 = 7;
}
message Scribble {
repeated Point point = 1;
}
message Arrow {
// The arrow head will be drawn at (x_end, y_end).
optional double x_start = 1;
@@ -192,6 +196,7 @@ message RenderAnnotation {
RoundedRectangle rounded_rectangle = 9;
FilledRoundedRectangle filled_rounded_rectangle = 10;
GradientLine gradient_line = 14;
Scribble scribble = 15;
}
// Thickness for drawing the annotation.
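
To show how the new Scribble message is meant to be populated, a hedged Python sketch follows; the render_data_pb2 import path is an assumption about how the proto bindings are generated and may differ in your build:

# Sketch only: assumes Python bindings for mediapipe/util/render_data.proto.
from mediapipe.util import render_data_pb2

annotation = render_data_pb2.RenderAnnotation()
annotation.color.r, annotation.color.g, annotation.color.b = 255, 0, 0
annotation.thickness = 2.0
for x, y in [(0.10, 0.20), (0.12, 0.22), (0.15, 0.21)]:
  point = annotation.scribble.point.add()
  point.normalized = True
  point.x = x
  point.y = y
# AnnotationRenderer::DrawScribble() above draws each of these points in turn.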