diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_frame_buffer.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_frame_buffer.cc index 093f50d76..6f6f6f11c 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_frame_buffer.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_frame_buffer.cc @@ -95,7 +95,8 @@ absl::Status FrameBufferProcessor::Convert(const mediapipe::Image& input, static_cast<int>(range_max) == 255); } - auto input_frame = input.GetGpuBuffer().GetReadView<FrameBuffer>(); + auto input_frame = + input.GetGpuBuffer(/*upload_to_gpu=*/false).GetReadView<FrameBuffer>(); const auto& output_shape = output_tensor.shape(); MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape)); FrameBuffer::Dimension output_dimension{/*width=*/output_shape.dims[2], diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index 22e6b0738..b6f50b840 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -1285,12 +1285,14 @@ cc_library( srcs = ["flat_color_image_calculator.cc"], deps = [ ":flat_color_image_calculator_cc_proto", + "//mediapipe/framework:calculator_contract", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", "//mediapipe/framework/port:opencv_core", + "//mediapipe/framework/port:ret_check", "//mediapipe/util:color_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", diff --git a/mediapipe/calculators/util/flat_color_image_calculator.cc b/mediapipe/calculators/util/flat_color_image_calculator.cc index 71d3582c5..f3b9c184c 100644 --- a/mediapipe/calculators/util/flat_color_image_calculator.cc +++ b/mediapipe/calculators/util/flat_color_image_calculator.cc @@ -15,14 +15,13 @@ #include #include "absl/status/status.h" -#include "absl/strings/str_cat.h" #include "mediapipe/calculators/util/flat_color_image_calculator.pb.h" #include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_contract.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame_opencv.h" -#include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/util/color.pb.h" namespace mediapipe { @@ -32,6 +31,7 @@ namespace { using ::mediapipe::api2::Input; using ::mediapipe::api2::Node; using ::mediapipe::api2::Output; +using ::mediapipe::api2::SideOutput; } // namespace // A calculator for generating an image filled with a single color. @@ -45,7 +45,8 @@ using ::mediapipe::api2::Output; // // Outputs: // IMAGE (Image) -// Image filled with the requested color. +// Image filled with the requested color. Can be either an output_stream +// or an output_side_packet.
// // Example usage: // node { @@ -68,9 +69,10 @@ class FlatColorImageCalculator : public Node { public: static constexpr Input<Image>::Optional kInImage{"IMAGE"}; static constexpr Input<Color>::Optional kInColor{"COLOR"}; - static constexpr Output<Image> kOutImage{"IMAGE"}; + static constexpr Output<Image>::Optional kOutImage{"IMAGE"}; + static constexpr SideOutput<Image>::Optional kOutSideImage{"IMAGE"}; - MEDIAPIPE_NODE_CONTRACT(kInImage, kInColor, kOutImage); + MEDIAPIPE_NODE_CONTRACT(kInImage, kInColor, kOutImage, kOutSideImage); static absl::Status UpdateContract(CalculatorContract* cc) { const auto& options = cc->Options<FlatColorImageCalculatorOptions>(); @@ -81,6 +83,13 @@ class FlatColorImageCalculator : public Node { RET_CHECK(kInColor(cc).IsConnected() ^ options.has_color()) << "Either set COLOR input stream, or set through options"; + RET_CHECK(kOutImage(cc).IsConnected() ^ kOutSideImage(cc).IsConnected()) + << "Set IMAGE either as output stream, or as output side packet"; + + RET_CHECK(!kOutSideImage(cc).IsConnected() || + (options.has_output_height() && options.has_output_width())) + << "Set size through options, when setting IMAGE as output side packet"; + + return absl::OkStatus(); } @@ -88,6 +97,9 @@ class FlatColorImageCalculator : public Node { absl::Status Process(CalculatorContext* cc) override; private: + std::optional<std::shared_ptr<ImageFrame>> CreateOutputFrame( + CalculatorContext* cc); + bool use_dimension_from_option_ = false; bool use_color_from_option_ = false; }; @@ -96,10 +108,31 @@ MEDIAPIPE_REGISTER_NODE(FlatColorImageCalculator); absl::Status FlatColorImageCalculator::Open(CalculatorContext* cc) { use_dimension_from_option_ = !kInImage(cc).IsConnected(); use_color_from_option_ = !kInColor(cc).IsConnected(); + + if (!kOutImage(cc).IsConnected()) { + std::optional<std::shared_ptr<ImageFrame>> output_frame = + CreateOutputFrame(cc); + if (output_frame.has_value()) { + kOutSideImage(cc).Set(Image(output_frame.value())); + } + } return absl::OkStatus(); } absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) { + if (kOutImage(cc).IsConnected()) { + std::optional<std::shared_ptr<ImageFrame>> output_frame = + CreateOutputFrame(cc); + if (output_frame.has_value()) { + kOutImage(cc).Send(Image(output_frame.value())); + } + } + + return absl::OkStatus(); +} + +std::optional<std::shared_ptr<ImageFrame>> +FlatColorImageCalculator::CreateOutputFrame(CalculatorContext* cc) { const auto& options = cc->Options<FlatColorImageCalculatorOptions>(); int output_height = -1; @@ -112,7 +145,7 @@ absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) { output_height = input_image.height(); output_width = input_image.width(); } else { - return absl::OkStatus(); + return std::nullopt; } Color color; @@ -121,7 +154,7 @@ absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) { } else if (!kInColor(cc).IsEmpty()) { color = kInColor(cc).Get(); } else { - return absl::OkStatus(); + return std::nullopt; } auto output_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB, @@ -130,9 +163,7 @@ absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) { output_mat.setTo(cv::Scalar(color.r(), color.g(), color.b())); - kOutImage(cc).Send(Image(output_frame)); - - return absl::OkStatus(); + return output_frame; } } // namespace mediapipe diff --git a/mediapipe/calculators/util/flat_color_image_calculator_test.cc b/mediapipe/calculators/util/flat_color_image_calculator_test.cc index 53c6de1b1..c09064bf2 100644 --- a/mediapipe/calculators/util/flat_color_image_calculator_test.cc +++ b/mediapipe/calculators/util/flat_color_image_calculator_test.cc @@ -113,6 +113,35 @@ TEST(FlatColorImageCalculatorTest, SpecifyDimensionThroughOptions) { } }
+TEST(FlatColorImageCalculatorTest, ProducesOutputSidePacket) { + CalculatorRunner runner(R"pb( + calculator: "FlatColorImageCalculator" + output_side_packet: "IMAGE:out_packet" + options { + [mediapipe.FlatColorImageCalculatorOptions.ext] { + output_width: 1 + output_height: 1 + color: { + r: 100, + g: 200, + b: 255, + } + } + } + )pb"); + + MP_ASSERT_OK(runner.Run()); + + const auto& image = runner.OutputSidePackets().Tag(kImageTag).Get<Image>(); + EXPECT_EQ(image.width(), 1); + EXPECT_EQ(image.height(), 1); + auto image_frame = image.GetImageFrameSharedPtr(); + const uint8_t* pixel_data = image_frame->PixelData(); + EXPECT_EQ(pixel_data[0], 100); + EXPECT_EQ(pixel_data[1], 200); + EXPECT_EQ(pixel_data[2], 255); +} + TEST(FlatColorImageCalculatorTest, FailureMissingDimension) { CalculatorRunner runner(R"pb( calculator: "FlatColorImageCalculator" @@ -206,5 +235,56 @@ TEST(FlatColorImageCalculatorTest, FailureDuplicateColor) { HasSubstr("Either set COLOR input stream")); } +TEST(FlatColorImageCalculatorTest, FailureDuplicateOutputs) { + CalculatorRunner runner(R"pb( + calculator: "FlatColorImageCalculator" + output_stream: "IMAGE:out_image" + output_side_packet: "IMAGE:out_packet" + options { + [mediapipe.FlatColorImageCalculatorOptions.ext] { + output_width: 1 + output_height: 1 + color: { + r: 100, + g: 200, + b: 255, + } + } + } + )pb"); + + ASSERT_THAT( + runner.Run().message(), + HasSubstr("Set IMAGE either as output stream, or as output side packet")); +} + +TEST(FlatColorImageCalculatorTest, FailureSettingInputImageOnOutputSidePacket) { + CalculatorRunner runner(R"pb( + calculator: "FlatColorImageCalculator" + input_stream: "IMAGE:image" + output_side_packet: "IMAGE:out_packet" + options { + [mediapipe.FlatColorImageCalculatorOptions.ext] { + color: { + r: 100, + g: 200, + b: 255, + } + } + } + )pb"); + + auto image_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB, + kImageWidth, kImageHeight); + + for (int ts = 0; ts < 3; ++ts) { + runner.MutableInputs()->Tag(kImageTag).packets.push_back( + MakePacket<Image>(image_frame).At(Timestamp(ts))); + } + ASSERT_THAT(runner.Run().message(), + HasSubstr("Set size through options, when setting IMAGE as " + "output side packet")); +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator_test.cc b/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator_test.cc index 84b229d80..4c9e96b88 100644 --- a/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator_test.cc +++ b/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator_test.cc @@ -190,14 +190,16 @@ TEST(PaddingEffectGeneratorTest, ScaleToMultipleOfTwo) { double target_aspect_ratio = 0.5; int expect_width = 14; int expect_height = input_height; - auto test_frame = absl::make_unique<ImageFrame>(/*format=*/ImageFormat::SRGB, - input_width, input_height); + ImageFrame test_frame(/*format=*/ImageFormat::SRGB, input_width, + input_height); + cv::Mat mat = formats::MatView(&test_frame); + mat = cv::Scalar(0, 0, 0); - PaddingEffectGenerator generator(test_frame->Width(), test_frame->Height(), + PaddingEffectGenerator generator(test_frame.Width(), test_frame.Height(), target_aspect_ratio, /*scale_to_multiple_of_two=*/true); ImageFrame result_frame; - MP_ASSERT_OK(generator.Process(*test_frame, 0.3, 40, 0.0, &result_frame)); + MP_ASSERT_OK(generator.Process(test_frame, 0.3, 40, 0.0, &result_frame)); EXPECT_EQ(result_frame.Width(), expect_width); EXPECT_EQ(result_frame.Height(), expect_height); } diff --git
a/mediapipe/framework/formats/image.h b/mediapipe/framework/formats/image.h index ffb6362f3..936a3554e 100644 --- a/mediapipe/framework/formats/image.h +++ b/mediapipe/framework/formats/image.h @@ -113,11 +113,11 @@ class Image { #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER #endif // !MEDIAPIPE_DISABLE_GPU - // Get a GPU view. Automatically uploads from CPU if needed. - const mediapipe::GpuBuffer GetGpuBuffer() const { -#if !MEDIAPIPE_DISABLE_GPU - if (use_gpu_ == false) ConvertToGpu(); -#endif // !MEDIAPIPE_DISABLE_GPU + // Provides access to the underlying GpuBuffer storage. + // Automatically uploads from CPU to GPU if needed and requested through the + // `upload_to_gpu` argument. + const mediapipe::GpuBuffer GetGpuBuffer(bool upload_to_gpu = true) const { + if (!use_gpu_ && upload_to_gpu) ConvertToGpu(); return gpu_buffer_; } diff --git a/mediapipe/framework/port/drishti_proto_alias_rules.bzl b/mediapipe/framework/port/drishti_proto_alias_rules.bzl new file mode 100644 index 000000000..7df141cbe --- /dev/null +++ b/mediapipe/framework/port/drishti_proto_alias_rules.bzl @@ -0,0 +1,31 @@ +"""Rules implementation for mediapipe_proto_alias.bzl, do not load directly.""" + +def _copy_header_impl(ctx): + source = ctx.attr.source.replace("//", "").replace(":", "/") + files = [] + for dep in ctx.attr.deps: + for header in dep[CcInfo].compilation_context.direct_headers: + if (header.short_path == source): + files.append(header) + if len(files) != 1: + fail("Expected exactly 1 source, got ", str(files)) + dest_file = ctx.actions.declare_file(ctx.attr.filename) + + # Use expand_template() with no substitutions as a simple copier. + ctx.actions.expand_template( + template = files[0], + output = dest_file, + substitutions = {}, + ) + return [DefaultInfo(files = depset([dest_file]))] + +copy_header = rule( + implementation = _copy_header_impl, + attrs = { + "filename": attr.string(), + "source": attr.string(), + "deps": attr.label_list(providers = [CcInfo]), + }, + output_to_genfiles = True, + outputs = {"out": "%{filename}"}, +) diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index fbdcf8c9e..4ae0bb607 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -791,6 +791,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@stblib//:stb_image", "@stblib//:stb_image_write", ], diff --git a/mediapipe/framework/tool/test_util.cc b/mediapipe/framework/tool/test_util.cc index 5642941e9..64b5072c5 100644 --- a/mediapipe/framework/tool/test_util.cc +++ b/mediapipe/framework/tool/test_util.cc @@ -26,6 +26,7 @@ #include "absl/status/status.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/substitute.h" #include "mediapipe/framework/calculator.pb.h" @@ -311,6 +312,13 @@ std::unique_ptr LoadTestPng(absl::string_view path, // Returns the path to the output if successful. 
absl::StatusOr SavePngTestOutput( const mediapipe::ImageFrame& image, absl::string_view prefix) { + absl::flat_hash_set supported_formats = { + ImageFormat::GRAY8, ImageFormat::SRGB, ImageFormat::SRGBA, + ImageFormat::LAB8, ImageFormat::SBGRA}; + if (!supported_formats.contains(image.Format())) { + return absl::CancelledError( + absl::StrFormat("Format %d can not be saved to PNG.", image.Format())); + } std::string now_string = absl::FormatTime(absl::Now()); std::string output_relative_path = absl::StrCat(prefix, "_", now_string, ".png"); diff --git a/mediapipe/model_maker/python/vision/object_detector/model.py b/mediapipe/model_maker/python/vision/object_detector/model.py index eac669786..e3eb3a651 100644 --- a/mediapipe/model_maker/python/vision/object_detector/model.py +++ b/mediapipe/model_maker/python/vision/object_detector/model.py @@ -59,7 +59,9 @@ class ObjectDetectorModel(tf.keras.Model): self._num_classes = num_classes self._model = self._build_model() checkpoint_folder = self._model_spec.downloaded_files.get_path() - checkpoint_file = os.path.join(checkpoint_folder, 'ckpt-277200') + checkpoint_file = os.path.join( + checkpoint_folder, self._model_spec.checkpoint_name + ) self.load_checkpoint(checkpoint_file) self._model.summary() self.loss_trackers = [ @@ -80,7 +82,10 @@ class ObjectDetectorModel(tf.keras.Model): num_scales=3, aspect_ratios=[0.5, 1.0, 2.0], anchor_size=3 ), backbone=configs.backbones.Backbone( - type='mobilenet', mobilenet=configs.backbones.MobileNet() + type='mobilenet', + mobilenet=configs.backbones.MobileNet( + model_id=self._model_spec.model_id + ), ), decoder=configs.decoders.Decoder( type='fpn', diff --git a/mediapipe/model_maker/python/vision/object_detector/model_spec.py b/mediapipe/model_maker/python/vision/object_detector/model_spec.py index 2ce838c71..9c89c4ed0 100644 --- a/mediapipe/model_maker/python/vision/object_detector/model_spec.py +++ b/mediapipe/model_maker/python/vision/object_detector/model_spec.py @@ -26,6 +26,12 @@ MOBILENET_V2_FILES = file_util.DownloadedFiles( is_folder=True, ) +MOBILENET_MULTI_AVG_FILES = file_util.DownloadedFiles( + 'object_detector/mobilenetmultiavg', + 'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv3.5_ssd_coco/mobilenetv3.5_ssd_i256_ckpt.tar.gz', + is_folder=True, +) + @dataclasses.dataclass class ModelSpec(object): @@ -38,13 +44,25 @@ class ModelSpec(object): stddev_rgb = (127.5,) downloaded_files: file_util.DownloadedFiles + checkpoint_name: str input_image_shape: List[int] + model_id: str mobilenet_v2_spec = functools.partial( ModelSpec, downloaded_files=MOBILENET_V2_FILES, + checkpoint_name='ckpt-277200', input_image_shape=[256, 256, 3], + model_id='MobileNetV2', +) + +mobilenet_multi_avg_spec = functools.partial( + ModelSpec, + downloaded_files=MOBILENET_MULTI_AVG_FILES, + checkpoint_name='ckpt-277200', + input_image_shape=[256, 256, 3], + model_id='MobileNetMultiAVG', ) @@ -53,6 +71,7 @@ class SupportedModels(enum.Enum): """Predefined object detector model specs supported by Model Maker.""" MOBILENET_V2 = mobilenet_v2_spec + MOBILENET_MULTI_AVG = mobilenet_multi_avg_spec @classmethod def get(cls, spec: 'SupportedModels') -> 'ModelSpec': diff --git a/mediapipe/tasks/cc/components/processors/proto/BUILD b/mediapipe/tasks/cc/components/processors/proto/BUILD index 82d4ea21b..a45c91633 100644 --- a/mediapipe/tasks/cc/components/processors/proto/BUILD +++ b/mediapipe/tasks/cc/components/processors/proto/BUILD @@ -93,3 +93,8 @@ mediapipe_proto_library( "//mediapipe/framework:calculator_proto", 
], ) + +mediapipe_proto_library( + name = "transformer_params_proto", + srcs = ["transformer_params.proto"], +) diff --git a/mediapipe/tasks/cc/components/processors/proto/transformer_params.proto b/mediapipe/tasks/cc/components/processors/proto/transformer_params.proto new file mode 100644 index 000000000..8c1daf277 --- /dev/null +++ b/mediapipe/tasks/cc/components/processors/proto/transformer_params.proto @@ -0,0 +1,46 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package mediapipe.tasks.components.processors.proto; + +option java_package = "com.google.mediapipe.tasks.components.processors.proto"; +option java_outer_classname = "TransformerParametersProto"; + +// The parameters of transformer (https://arxiv.org/pdf/1706.03762.pdf) +message TransformerParameters { + // Batch size of tensors. + int32 batch_size = 1; + + // Maximum sequence length of the input/output tensor. + int32 max_seq_length = 2; + + // Embedding dimension (or model dimension), `d_model` in the paper. + // `d_k` == `d_v` == `d_model`/`h`. + int32 embedding_dim = 3; + + // Hidden dimension used in the feedforward layer, `d_ff` in the paper. + int32 hidden_dimension = 4; + + // Head dimension, `d_k` or `d_v` in the paper. + int32 head_dimension = 5; + + // Number of heads, `h` in the paper. + int32 num_heads = 6; + + // Number of stacked transformers, `N` in the paper. + int32 num_stacks = 7; +} diff --git a/mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc b/mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc index 8d83ac2c8..2e5f7e416 100644 --- a/mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/face_detector/face_detector_graph.cc @@ -242,7 +242,7 @@ class FaceDetectorGraph : public core::ModelTaskGraph { auto matrix = preprocessing.Out(kMatrixTag); auto image_size = preprocessing.Out(kImageSizeTag); - // Face detection model inferece. + // Face detection model inference. auto& inference = AddInference( model_resources, subgraph_options.base_options().acceleration(), graph); preprocessed_tensors >> inference.In(kTensorsTag); diff --git a/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer_graph.cc b/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer_graph.cc index 7c4e6f138..cb49ef59d 100644 --- a/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer_graph.cc +++ b/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer_graph.cc @@ -199,7 +199,9 @@ void ConfigureTensorsToImageCalculator( // STYLIZED_IMAGE - mediapipe::Image // The face stylization output image. // FACE_ALIGNMENT - mediapipe::Image -// The face alignment output image. +// The aligned face image that is fed to the face stylization model to +// perform stylization. Also useful for preparing face stylization training +// data. 
// IMAGE - mediapipe::Image // The input image that the face landmarker runs on and has the pixel data // stored on the target storage (CPU vs GPU). @@ -211,6 +213,7 @@ void ConfigureTensorsToImageCalculator( // input_stream: "NORM_RECT:norm_rect" // output_stream: "IMAGE:image_out" // output_stream: "STYLIZED_IMAGE:stylized_image" +// output_stream: "FACE_ALIGNMENT:face_alignment_image" // options { // [mediapipe.tasks.vision.face_stylizer.proto.FaceStylizerGraphOptions.ext] // { @@ -248,7 +251,7 @@ class FaceStylizerGraph : public core::ModelTaskGraph { ->mutable_face_landmarker_graph_options(), graph[Input(kImageTag)], graph[Input::Optional(kNormRectTag)], graph)); - const ModelResources* face_stylizer_model_resources; + const ModelResources* face_stylizer_model_resources = nullptr; if (output_stylized) { ASSIGN_OR_RETURN( const auto* model_resources, @@ -332,7 +335,7 @@ class FaceStylizerGraph : public core::ModelTaskGraph { auto face_rect = face_to_rect.Out(kNormRectTag); std::optional> face_alignment; - // Output face alignment only. + // Output aligned face only. // In this case, the face stylization model inference is not required. // However, to keep consistent with the inference preprocessing steps, the // ImageToTensorCalculator is still used to perform image rotation, diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc index af2a3f50c..c0d89c87d 100644 --- a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc @@ -28,6 +28,7 @@ limitations under the License. #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/containers/keypoint.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" @@ -60,6 +61,8 @@ constexpr absl::string_view kNormRectTag{"NORM_RECT"}; constexpr absl::string_view kSubgraphTypeName{ "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"}; +using components::containers::NormalizedKeypoint; + using ::mediapipe::CalculatorGraphConfig; using ::mediapipe::Image; using ::mediapipe::NormalizedRect; @@ -115,7 +118,7 @@ absl::StatusOr ConvertRoiToRenderData(const RegionOfInterest& roi) { case RegionOfInterest::Format::kUnspecified: return absl::InvalidArgumentError( "RegionOfInterest format not specified"); - case RegionOfInterest::Format::kKeyPoint: + case RegionOfInterest::Format::kKeyPoint: { RET_CHECK(roi.keypoint.has_value()); auto* annotation = result.add_render_annotations(); annotation->mutable_color()->set_r(255); @@ -124,6 +127,19 @@ absl::StatusOr ConvertRoiToRenderData(const RegionOfInterest& roi) { point->set_x(roi.keypoint->x); point->set_y(roi.keypoint->y); return result; + } + case RegionOfInterest::Format::kScribble: { + RET_CHECK(roi.scribble.has_value()); + auto* annotation = result.add_render_annotations(); + annotation->mutable_color()->set_r(255); + for (const NormalizedKeypoint& keypoint : *(roi.scribble)) { + auto* point = annotation->mutable_scribble()->add_point(); + point->set_normalized(true); + point->set_x(keypoint.x); + point->set_y(keypoint.y); + } + return result; + } } return absl::UnimplementedError("Unrecognized format"); } diff --git 
a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h index ad4a238df..ad8a558df 100644 --- a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h @@ -53,6 +53,7 @@ struct RegionOfInterest { enum class Format { kUnspecified = 0, // Format not specified. kKeyPoint = 1, // Using keypoint to represent ROI. + kScribble = 2, // Using scribble to represent ROI. }; // Specifies the format used to specify the region-of-interest. Note that @@ -61,8 +62,13 @@ struct RegionOfInterest { Format format = Format::kUnspecified; // Represents the ROI in keypoint format, this should be non-nullopt if - // `format` is `KEYPOINT`. + // `format` is `kKeyPoint`. std::optional<components::containers::NormalizedKeypoint> keypoint; + + // Represents the ROI in scribble format, this should be non-nullopt if + // `format` is `kScribble`. + std::optional<std::vector<components::containers::NormalizedKeypoint>> + scribble; }; // Performs interactive segmentation on images. diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc index 443247aea..16d065f61 100644 --- a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc @@ -18,9 +18,12 @@ limitations under the License. #include #include #include +#include <variant> +#include <vector> #include "absl/flags/flag.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/formats/image.h" @@ -179,22 +182,46 @@ TEST_F(CreateFromOptionsTest, FailsWithNeitherOutputSet) { struct InteractiveSegmenterTestParams { std::string test_name; RegionOfInterest::Format format; - NormalizedKeypoint roi; + std::variant<NormalizedKeypoint, std::vector<NormalizedKeypoint>> roi; absl::string_view golden_mask_file; float similarity_threshold; }; -using SucceedSegmentationWithRoi = - ::testing::TestWithParam<InteractiveSegmenterTestParams>; +class SucceedSegmentationWithRoi + : public ::testing::TestWithParam<InteractiveSegmenterTestParams> { + public: + absl::StatusOr<RegionOfInterest> TestParamsToTaskOptions() { + const InteractiveSegmenterTestParams& params = GetParam(); + + RegionOfInterest interaction_roi; + interaction_roi.format = params.format; + switch (params.format) { + case (RegionOfInterest::Format::kKeyPoint): { + interaction_roi.keypoint = std::get<NormalizedKeypoint>(params.roi); + break; + } + case (RegionOfInterest::Format::kScribble): { + interaction_roi.scribble = + std::get<std::vector<NormalizedKeypoint>>(params.roi); + break; + } + default: { + return absl::InvalidArgumentError("Unknown ROI format"); + } + } + + return interaction_roi; + } +}; TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) { + MP_ASSERT_OK_AND_ASSIGN(RegionOfInterest interaction_roi, + TestParamsToTaskOptions()); const InteractiveSegmenterTestParams& params = GetParam(); + MP_ASSERT_OK_AND_ASSIGN( Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg))); - RegionOfInterest interaction_roi; - interaction_roi.format = params.format; - interaction_roi.keypoint = params.roi; auto options = std::make_unique<InteractiveSegmenterOptions>(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kPtmModel); @@ -220,13 +247,13 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) { } TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) { - const auto& params = GetParam(); + MP_ASSERT_OK_AND_ASSIGN(RegionOfInterest
interaction_roi, + TestParamsToTaskOptions()); + const InteractiveSegmenterTestParams& params = GetParam(); + MP_ASSERT_OK_AND_ASSIGN( Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg))); - RegionOfInterest interaction_roi; - interaction_roi.format = params.format; - interaction_roi.keypoint = params.roi; auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kPtmModel); @@ -253,11 +280,23 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) { INSTANTIATE_TEST_SUITE_P( SucceedSegmentationWithRoiTest, SucceedSegmentationWithRoi, ::testing::ValuesIn( - {{"PointToDog1", RegionOfInterest::Format::kKeyPoint, + {// Keypoint input. + {"PointToDog1", RegionOfInterest::Format::kKeyPoint, NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsMaskDog1, 0.84f}, {"PointToDog2", RegionOfInterest::Format::kKeyPoint, NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsMaskDog2, - kGoldenMaskSimilarity}}), + kGoldenMaskSimilarity}, + // Scribble input. + {"ScribbleToDog1", RegionOfInterest::Format::kScribble, + std::vector{NormalizedKeypoint{0.44, 0.70}, + NormalizedKeypoint{0.44, 0.71}, + NormalizedKeypoint{0.44, 0.72}}, + kCatsAndDogsMaskDog1, 0.84f}, + {"ScribbleToDog2", RegionOfInterest::Format::kScribble, + std::vector{NormalizedKeypoint{0.66, 0.66}, + NormalizedKeypoint{0.66, 0.67}, + NormalizedKeypoint{0.66, 0.68}}, + kCatsAndDogsMaskDog2, kGoldenMaskSimilarity}}), [](const ::testing::TestParamInfo& info) { return info.param.test_name; }); diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc index 826de5ec4..7889212e8 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc @@ -108,9 +108,18 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, ->mutable_model_asset(), is_copy); } - pose_detector_graph_options->mutable_base_options() - ->mutable_acceleration() - ->CopyFrom(options->base_options().acceleration()); + if (options->base_options().acceleration().has_gpu()) { + core::proto::Acceleration gpu_accel; + gpu_accel.mutable_gpu()->set_use_advanced_gpu_api(true); + pose_detector_graph_options->mutable_base_options() + ->mutable_acceleration() + ->CopyFrom(gpu_accel); + + } else { + pose_detector_graph_options->mutable_base_options() + ->mutable_acceleration() + ->CopyFrom(options->base_options().acceleration()); + } pose_detector_graph_options->mutable_base_options()->set_use_stream_mode( options->base_options().use_stream_mode()); auto* pose_landmarks_detector_graph_options = diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPDetection.m b/mediapipe/tasks/ios/components/containers/sources/MPPDetection.m index c245478db..c61cf0b39 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPDetection.m +++ b/mediapipe/tasks/ios/components/containers/sources/MPPDetection.m @@ -28,7 +28,12 @@ return self; } -// TODO: Implement hash +- (NSUInteger)hash { + NSUInteger nonNullPropertiesHash = + @(self.location.x).hash ^ @(self.location.y).hash ^ @(self.score).hash; + + return self.label ? 
nonNullPropertiesHash ^ self.label.hash : nonNullPropertiesHash; +} - (BOOL)isEqual:(nullable id)object { if (!object) { diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index 5be0e233f..a2dbe351a 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -180,6 +180,7 @@ android_library( srcs = [ "poselandmarker/PoseLandmarker.java", "poselandmarker/PoseLandmarkerResult.java", + "poselandmarker/PoseLandmarksConnections.java", ], javacopts = [ "-Xep:AndroidJdkLibsChecker:OFF", @@ -212,6 +213,7 @@ android_library( "handlandmarker/HandLandmark.java", "handlandmarker/HandLandmarker.java", "handlandmarker/HandLandmarkerResult.java", + "handlandmarker/HandLandmarksConnections.java", ], javacopts = [ "-Xep:AndroidJdkLibsChecker:OFF", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java index 070806522..5964cef2c 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java @@ -77,11 +77,13 @@ public class BaseVisionTaskApi implements AutoCloseable { } Map inputPackets = new HashMap<>(); inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); - inputPackets.put( - normRectStreamName, - runner - .getPacketCreator() - .createProto(convertToNormalizedRect(imageProcessingOptions, image))); + if (!normRectStreamName.isEmpty()) { + inputPackets.put( + normRectStreamName, + runner + .getPacketCreator() + .createProto(convertToNormalizedRect(imageProcessingOptions, image))); + } return runner.process(inputPackets); } @@ -105,11 +107,13 @@ public class BaseVisionTaskApi implements AutoCloseable { } Map inputPackets = new HashMap<>(); inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); - inputPackets.put( - normRectStreamName, - runner - .getPacketCreator() - .createProto(convertToNormalizedRect(imageProcessingOptions, image))); + if (!normRectStreamName.isEmpty()) { + inputPackets.put( + normRectStreamName, + runner + .getPacketCreator() + .createProto(convertToNormalizedRect(imageProcessingOptions, image))); + } return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); } @@ -133,11 +137,13 @@ public class BaseVisionTaskApi implements AutoCloseable { } Map inputPackets = new HashMap<>(); inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); - inputPackets.put( - normRectStreamName, - runner - .getPacketCreator() - .createProto(convertToNormalizedRect(imageProcessingOptions, image))); + if (!normRectStreamName.isEmpty()) { + inputPackets.put( + normRectStreamName, + runner + .getPacketCreator() + .createProto(convertToNormalizedRect(imageProcessingOptions, image))); + } runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarksConnections.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarksConnections.java new file mode 100644 index 000000000..c60923840 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarksConnections.java @@ -0,0 +1,105 @@ +// Copyright 2023 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.handlandmarker; + +import com.google.auto.value.AutoValue; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** Hand landmarks connection constants. */ +public final class HandLandmarksConnections { + + /** Value class representing hand landmarks connection. */ + @AutoValue + public abstract static class Connection { + static Connection create(int start, int end) { + return new AutoValue_HandLandmarksConnections_Connection(start, end); + } + + public abstract int start(); + + public abstract int end(); + } + + @SuppressWarnings("ConstantCaseForConstants") + public static final Set<Connection> HAND_PALM_CONNECTIONS = + Collections.unmodifiableSet( + new HashSet<>( + Arrays.asList( + Connection.create(0, 1), + Connection.create(0, 5), + Connection.create(9, 13), + Connection.create(13, 17), + Connection.create(5, 9), + Connection.create(0, 17)))); + + @SuppressWarnings("ConstantCaseForConstants") + public static final Set<Connection> HAND_THUMB_CONNECTIONS = + Collections.unmodifiableSet( + new HashSet<>( + Arrays.asList( + Connection.create(1, 2), Connection.create(2, 3), Connection.create(3, 4)))); + + @SuppressWarnings("ConstantCaseForConstants") + public static final Set<Connection> HAND_INDEX_FINGER_CONNECTIONS = + Collections.unmodifiableSet( + new HashSet<>( + Arrays.asList( + Connection.create(5, 6), Connection.create(6, 7), Connection.create(7, 8)))); + + @SuppressWarnings("ConstantCaseForConstants") + public static final Set<Connection> HAND_MIDDLE_FINGER_CONNECTIONS = + Collections.unmodifiableSet( + new HashSet<>( + Arrays.asList( + Connection.create(9, 10), Connection.create(10, 11), Connection.create(11, 12)))); + + @SuppressWarnings("ConstantCaseForConstants") + public static final Set<Connection> HAND_RING_FINGER_CONNECTIONS = + Collections.unmodifiableSet( + new HashSet<>( + Arrays.asList( + Connection.create(13, 14), + Connection.create(14, 15), + Connection.create(15, 16)))); + + @SuppressWarnings("ConstantCaseForConstants") + public static final Set<Connection> HAND_PINKY_FINGER_CONNECTIONS = + Collections.unmodifiableSet( + new HashSet<>( + Arrays.asList( + Connection.create(17, 18), + Connection.create(18, 19), + Connection.create(19, 20)))); + + @SuppressWarnings("ConstantCaseForConstants") + public static final Set<Connection> HAND_CONNECTIONS = + Collections.unmodifiableSet( + Stream.of( + HAND_PALM_CONNECTIONS.stream(), + HAND_THUMB_CONNECTIONS.stream(), + HAND_INDEX_FINGER_CONNECTIONS.stream(), + HAND_MIDDLE_FINGER_CONNECTIONS.stream(), + HAND_RING_FINGER_CONNECTIONS.stream(), + HAND_PINKY_FINGER_CONNECTIONS.stream()) + .flatMap(i -> i) + .collect(Collectors.toSet())); + + private HandLandmarksConnections() {} +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java
b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java index 52a5f2a67..e9ff1f2b5 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java @@ -502,6 +502,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { /** The Region-Of-Interest (ROI) to interact with. */ public static class RegionOfInterest { private NormalizedKeypoint keypoint; + private List<NormalizedKeypoint> scribble; private RegionOfInterest() {} @@ -514,6 +515,16 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { roi.keypoint = keypoint; return roi; } + + /** + * Creates a {@link RegionOfInterest} instance representing scribbles over the object that the + * user wants to segment. + */ + public static RegionOfInterest create(List<NormalizedKeypoint> scribble) { + RegionOfInterest roi = new RegionOfInterest(); + roi.scribble = scribble; + return roi; + } } /** @@ -535,6 +546,18 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { .setX(roi.keypoint.x()) .setY(roi.keypoint.y()))) .build(); + } else if (roi.scribble != null) { + RenderAnnotation.Scribble.Builder scribbleBuilder = RenderAnnotation.Scribble.newBuilder(); + for (NormalizedKeypoint p : roi.scribble) { + scribbleBuilder.addPoint(RenderAnnotation.Point.newBuilder().setX(p.x()).setY(p.y())); + } + + return builder + .addRenderAnnotations( + RenderAnnotation.newBuilder() + .setColor(Color.newBuilder().setR(255)) + .setScribble(scribbleBuilder)) + .build(); } throw new IllegalArgumentException( diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarksConnections.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarksConnections.java new file mode 100644 index 000000000..9be6a9aeb --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarksConnections.java @@ -0,0 +1,80 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.poselandmarker; + +import com.google.auto.value.AutoValue; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + +/** Pose landmarks connection constants. */ +public final class PoseLandmarksConnections { + + /** Value class representing pose landmarks connection.
*/ + @AutoValue + public abstract static class Connection { + static Connection create(int start, int end) { + return new AutoValue_PoseLandmarksConnections_Connection(start, end); + } + + public abstract int start(); + + public abstract int end(); + } + + @SuppressWarnings("ConstantCaseForConstants") + public static final Set POSE_LANDMARKS = + Collections.unmodifiableSet( + new HashSet<>( + Arrays.asList( + Connection.create(0, 1), + Connection.create(1, 2), + Connection.create(2, 3), + Connection.create(3, 7), + Connection.create(0, 4), + Connection.create(4, 5), + Connection.create(5, 6), + Connection.create(6, 8), + Connection.create(9, 10), + Connection.create(11, 12), + Connection.create(11, 13), + Connection.create(13, 15), + Connection.create(15, 17), + Connection.create(15, 19), + Connection.create(15, 21), + Connection.create(17, 19), + Connection.create(12, 14), + Connection.create(14, 16), + Connection.create(16, 18), + Connection.create(16, 20), + Connection.create(16, 22), + Connection.create(18, 20), + Connection.create(11, 23), + Connection.create(12, 24), + Connection.create(23, 24), + Connection.create(23, 25), + Connection.create(24, 26), + Connection.create(25, 27), + Connection.create(26, 28), + Connection.create(27, 29), + Connection.create(28, 30), + Connection.create(29, 31), + Connection.create(30, 32), + Connection.create(27, 31), + Connection.create(28, 32)))); + + private PoseLandmarksConnections() {} +} diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java index 506036ba2..a534970f7 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java @@ -27,6 +27,7 @@ import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenterResult; import com.google.mediapipe.tasks.vision.interactivesegmenter.InteractiveSegmenter.InteractiveSegmenterOptions; import java.io.InputStream; +import java.util.ArrayList; import java.util.List; import org.junit.Test; import org.junit.runner.RunWith; @@ -36,7 +37,8 @@ import org.junit.runners.Suite.SuiteClasses; /** Test for {@link InteractiveSegmenter}. 
*/ @RunWith(Suite.class) @SuiteClasses({ - InteractiveSegmenterTest.General.class, + InteractiveSegmenterTest.KeypointRoi.class, + InteractiveSegmenterTest.ScribbleRoi.class, }) public class InteractiveSegmenterTest { private static final String DEEPLAB_MODEL_FILE = "ptm_512_hdt_ptm_woid.tflite"; @@ -44,7 +46,7 @@ public class InteractiveSegmenterTest { private static final int MAGNIFICATION_FACTOR = 10; @RunWith(AndroidJUnit4.class) - public static final class General extends InteractiveSegmenterTest { + public static final class KeypointRoi extends InteractiveSegmenterTest { @Test public void segment_successWithCategoryMask() throws Exception { final String inputImageName = CATS_AND_DOGS_IMAGE; @@ -86,6 +88,57 @@ public class InteractiveSegmenterTest { } } + @RunWith(AndroidJUnit4.class) + public static final class ScribbleRoi extends InteractiveSegmenterTest { + @Test + public void segment_successWithCategoryMask() throws Exception { + final String inputImageName = CATS_AND_DOGS_IMAGE; + ArrayList scribble = new ArrayList<>(); + scribble.add(NormalizedKeypoint.create(0.25f, 0.9f)); + scribble.add(NormalizedKeypoint.create(0.25f, 0.91f)); + scribble.add(NormalizedKeypoint.create(0.25f, 0.92f)); + final InteractiveSegmenter.RegionOfInterest roi = + InteractiveSegmenter.RegionOfInterest.create(scribble); + InteractiveSegmenterOptions options = + InteractiveSegmenterOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputConfidenceMasks(false) + .setOutputCategoryMask(true) + .build(); + InteractiveSegmenter imageSegmenter = + InteractiveSegmenter.createFromOptions( + ApplicationProvider.getApplicationContext(), options); + MPImage image = getImageFromAsset(inputImageName); + ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi); + assertThat(actualResult.categoryMask().isPresent()).isTrue(); + } + + @Test + public void segment_successWithConfidenceMask() throws Exception { + final String inputImageName = CATS_AND_DOGS_IMAGE; + ArrayList scribble = new ArrayList<>(); + scribble.add(NormalizedKeypoint.create(0.25f, 0.9f)); + scribble.add(NormalizedKeypoint.create(0.25f, 0.91f)); + scribble.add(NormalizedKeypoint.create(0.25f, 0.92f)); + final InteractiveSegmenter.RegionOfInterest roi = + InteractiveSegmenter.RegionOfInterest.create(scribble); + InteractiveSegmenterOptions options = + InteractiveSegmenterOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputConfidenceMasks(true) + .setOutputCategoryMask(false) + .build(); + InteractiveSegmenter imageSegmenter = + InteractiveSegmenter.createFromOptions( + ApplicationProvider.getApplicationContext(), options); + ImageSegmenterResult actualResult = + imageSegmenter.segment(getImageFromAsset(inputImageName), roi); + assertThat(actualResult.confidenceMasks().isPresent()).isTrue(); + List confidenceMasks = actualResult.confidenceMasks().get(); + assertThat(confidenceMasks.size()).isEqualTo(2); + } + } + private static MPImage getImageFromAsset(String filePath) throws Exception { AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); InputStream istr = assetManager.open(filePath); diff --git a/mediapipe/tasks/python/components/containers/landmark_detection_result.py b/mediapipe/tasks/python/components/containers/landmark_detection_result.py index c60ad850c..fdb719b92 100644 --- a/mediapipe/tasks/python/components/containers/landmark_detection_result.py +++ 
b/mediapipe/tasks/python/components/containers/landmark_detection_result.py @@ -39,9 +39,11 @@ _Landmark = landmark_module.Landmark class LandmarksDetectionResult: """Represents the landmarks detection result. - Attributes: landmarks : A list of `NormalizedLandmark` objects. categories : A - list of `Category` objects. world_landmarks : A list of `Landmark` objects. - rect : A `NormalizedRect` object. + Attributes: + landmarks: A list of `NormalizedLandmark` objects. + categories: A list of `Category` objects. + world_landmarks: A list of `Landmark` objects. + rect: A `NormalizedRect` object. """ landmarks: Optional[List[_NormalizedLandmark]] diff --git a/mediapipe/tasks/python/test/text/BUILD b/mediapipe/tasks/python/test/text/BUILD index 44f584b66..be352c84d 100644 --- a/mediapipe/tasks/python/test/text/BUILD +++ b/mediapipe/tasks/python/test/text/BUILD @@ -49,3 +49,18 @@ py_test( "//mediapipe/tasks/python/text:text_embedder", ], ) + +py_test( + name = "language_detector_test", + srcs = ["language_detector_test.py"], + data = [ + "//mediapipe/tasks/testdata/text:language_detector", + ], + deps = [ + "//mediapipe/tasks/python/components/containers:category", + "//mediapipe/tasks/python/components/containers:classification_result", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/test:test_utils", + "//mediapipe/tasks/python/text:language_detector", + ], +) diff --git a/mediapipe/tasks/python/test/text/language_detector_test.py b/mediapipe/tasks/python/test/text/language_detector_test.py new file mode 100644 index 000000000..45a04d564 --- /dev/null +++ b/mediapipe/tasks/python/test/text/language_detector_test.py @@ -0,0 +1,228 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for language detector.""" + +import enum +import os + +from absl.testing import absltest +from absl.testing import parameterized + +from mediapipe.tasks.python.components.containers import category +from mediapipe.tasks.python.components.containers import classification_result as classification_result_module +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.test import test_utils +from mediapipe.tasks.python.text import language_detector + +LanguageDetectorResult = language_detector.LanguageDetectorResult +LanguageDetectorPrediction = ( + language_detector.LanguageDetectorResult.Detection +) +_BaseOptions = base_options_module.BaseOptions +_Category = category.Category +_Classifications = classification_result_module.Classifications +_LanguageDetector = language_detector.LanguageDetector +_LanguageDetectorOptions = language_detector.LanguageDetectorOptions + +_LANGUAGE_DETECTOR_MODEL = "language_detector.tflite" +_TEST_DATA_DIR = "mediapipe/tasks/testdata/text" + +_SCORE_THRESHOLD = 0.3 +_EN_TEXT = "To be, or not to be, that is the question" +_EN_EXPECTED_RESULT = LanguageDetectorResult( + [LanguageDetectorPrediction("en", 0.999856)] +) +_FR_TEXT = ( + "Il y a beaucoup de bouches qui parlent et fort peu de têtes qui pensent." +) +_FR_EXPECTED_RESULT = LanguageDetectorResult( + [LanguageDetectorPrediction("fr", 0.999781)] +) +_RU_TEXT = "это какой-то английский язык" +_RU_EXPECTED_RESULT = LanguageDetectorResult( + [LanguageDetectorPrediction("ru", 0.993362)] +) +_MIXED_TEXT = "分久必合合久必分" +_MIXED_EXPECTED_RESULT = LanguageDetectorResult([ + LanguageDetectorPrediction("zh", 0.505424), + LanguageDetectorPrediction("ja", 0.481617), +]) +_TOLERANCE = 1e-6 + + +class ModelFileType(enum.Enum): + FILE_CONTENT = 1 + FILE_NAME = 2 + + +class LanguageDetectorTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.model_path = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _LANGUAGE_DETECTOR_MODEL) + ) + + def _expect_language_detector_result_correct( + self, + actual_result: LanguageDetectorResult, + expect_result: LanguageDetectorResult, + ): + for i, prediction in enumerate(actual_result.detections): + expected_prediction = expect_result.detections[i] + self.assertEqual( + prediction.language_code, + expected_prediction.language_code, + ) + self.assertAlmostEqual( + prediction.probability, + expected_prediction.probability, + delta=_TOLERANCE, + ) + + def test_create_from_file_succeeds_with_valid_model_path(self): + # Creates with default option and valid model file successfully. + with _LanguageDetector.create_from_model_path(self.model_path) as detector: + self.assertIsInstance(detector, _LanguageDetector) + + def test_create_from_options_succeeds_with_valid_model_path(self): + # Creates with options containing model file successfully. 
+ base_options = _BaseOptions(model_asset_path=self.model_path) + options = _LanguageDetectorOptions(base_options=base_options) + with _LanguageDetector.create_from_options(options) as detector: + self.assertIsInstance(detector, _LanguageDetector) + + def test_create_from_options_fails_with_invalid_model_path(self): + with self.assertRaisesRegex( + RuntimeError, "Unable to open file at /path/to/invalid/model.tflite" + ): + base_options = _BaseOptions( + model_asset_path="/path/to/invalid/model.tflite" + ) + options = _LanguageDetectorOptions(base_options=base_options) + _LanguageDetector.create_from_options(options) + + def test_create_from_options_succeeds_with_valid_model_content(self): + # Creates with options containing model content successfully. + with open(self.model_path, "rb") as f: + base_options = _BaseOptions(model_asset_buffer=f.read()) + options = _LanguageDetectorOptions(base_options=base_options) + detector = _LanguageDetector.create_from_options(options) + self.assertIsInstance(detector, _LanguageDetector) + + @parameterized.parameters( + (ModelFileType.FILE_NAME, _EN_TEXT, _EN_EXPECTED_RESULT), + (ModelFileType.FILE_CONTENT, _EN_TEXT, _EN_EXPECTED_RESULT), + (ModelFileType.FILE_NAME, _FR_TEXT, _FR_EXPECTED_RESULT), + (ModelFileType.FILE_CONTENT, _FR_TEXT, _FR_EXPECTED_RESULT), + (ModelFileType.FILE_NAME, _RU_TEXT, _RU_EXPECTED_RESULT), + (ModelFileType.FILE_CONTENT, _RU_TEXT, _RU_EXPECTED_RESULT), + (ModelFileType.FILE_NAME, _MIXED_TEXT, _MIXED_EXPECTED_RESULT), + (ModelFileType.FILE_CONTENT, _MIXED_TEXT, _MIXED_EXPECTED_RESULT), + ) + def test_detect(self, model_file_type, text, expected_result): + # Creates detector. + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=self.model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(self.model_path, "rb") as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError("model_file_type is invalid.") + + options = _LanguageDetectorOptions( + base_options=base_options, score_threshold=_SCORE_THRESHOLD + ) + detector = _LanguageDetector.create_from_options(options) + + # Performs language detection on the input. + text_result = detector.detect(text) + # Comparing results. + self._expect_language_detector_result_correct(text_result, expected_result) + # Closes the detector explicitly when the detector is not used in + # a context. + detector.close() + + @parameterized.parameters( + (ModelFileType.FILE_NAME, _EN_TEXT, _EN_EXPECTED_RESULT), + (ModelFileType.FILE_NAME, _FR_TEXT, _FR_EXPECTED_RESULT), + (ModelFileType.FILE_NAME, _RU_TEXT, _RU_EXPECTED_RESULT), + (ModelFileType.FILE_CONTENT, _MIXED_TEXT, _MIXED_EXPECTED_RESULT), + ) + def test_detect_in_context(self, model_file_type, text, expected_result): + # Creates detector. + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=self.model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(self.model_path, "rb") as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError("model_file_type is invalid.") + + options = _LanguageDetectorOptions( + base_options=base_options, score_threshold=_SCORE_THRESHOLD + ) + with _LanguageDetector.create_from_options(options) as detector: + # Performs language detection on the input. + text_result = detector.detect(text) + # Comparing results. 
+ self._expect_language_detector_result_correct( + text_result, expected_result + ) + + def test_allowlist_option(self): + # Creates detector. + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _LanguageDetectorOptions( + base_options=base_options, + score_threshold=_SCORE_THRESHOLD, + category_allowlist=["ja"], + ) + with _LanguageDetector.create_from_options(options) as detector: + # Performs language detection on the input. + text_result = detector.detect(_MIXED_TEXT) + # Comparing results. + expected_result = LanguageDetectorResult( + [LanguageDetectorPrediction("ja", 0.481617)] + ) + self._expect_language_detector_result_correct( + text_result, expected_result + ) + + def test_denylist_option(self): + # Creates detector. + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _LanguageDetectorOptions( + base_options=base_options, + score_threshold=_SCORE_THRESHOLD, + category_denylist=["ja"], + ) + with _LanguageDetector.create_from_options(options) as detector: + # Performs language detection on the input. + text_result = detector.detect(_MIXED_TEXT) + # Comparing results. + expected_result = LanguageDetectorResult( + [LanguageDetectorPrediction("zh", 0.505424)] + ) + self._expect_language_detector_result_correct( + text_result, expected_result + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index e55e1b572..d555402b8 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -185,3 +185,20 @@ py_test( "@com_google_protobuf//:protobuf_python", ], ) + +py_test( + name = "face_aligner_test", + srcs = ["face_aligner_test.py"], + data = [ + "//mediapipe/tasks/testdata/vision:test_images", + "//mediapipe/tasks/testdata/vision:test_models", + ], + deps = [ + "//mediapipe/python:_framework_bindings", + "//mediapipe/tasks/python/components/containers:rect", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/test:test_utils", + "//mediapipe/tasks/python/vision:face_aligner", + "//mediapipe/tasks/python/vision/core:image_processing_options", + ], +) diff --git a/mediapipe/tasks/python/test/vision/face_aligner_test.py b/mediapipe/tasks/python/test/vision/face_aligner_test.py new file mode 100644 index 000000000..324bd0359 --- /dev/null +++ b/mediapipe/tasks/python/test/vision/face_aligner_test.py @@ -0,0 +1,190 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for face aligner.""" + +import enum +import os + +from absl.testing import absltest +from absl.testing import parameterized + +from mediapipe.python._framework_bindings import image as image_module +from mediapipe.tasks.python.components.containers import rect +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.test import test_utils +from mediapipe.tasks.python.vision import face_aligner +from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module + +_BaseOptions = base_options_module.BaseOptions +_Rect = rect.Rect +_Image = image_module.Image +_FaceAligner = face_aligner.FaceAligner +_FaceAlignerOptions = face_aligner.FaceAlignerOptions +_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions + +_MODEL = 'face_landmarker_v2.task' +_LARGE_FACE_IMAGE = 'portrait.jpg' +_MODEL_IMAGE_SIZE = 256 +_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision' + + +class ModelFileType(enum.Enum): + FILE_CONTENT = 1 + FILE_NAME = 2 + + +class FaceAlignerTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.test_image = _Image.create_from_file( + test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _LARGE_FACE_IMAGE) + ) + ) + self.model_path = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _MODEL) + ) + + def test_create_from_file_succeeds_with_valid_model_path(self): + # Creates with default option and valid model file successfully. + with _FaceAligner.create_from_model_path(self.model_path) as aligner: + self.assertIsInstance(aligner, _FaceAligner) + + def test_create_from_options_succeeds_with_valid_model_path(self): + # Creates with options containing model file successfully. + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _FaceAlignerOptions(base_options=base_options) + with _FaceAligner.create_from_options(options) as aligner: + self.assertIsInstance(aligner, _FaceAligner) + + def test_create_from_options_fails_with_invalid_model_path(self): + with self.assertRaisesRegex( + RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite' + ): + base_options = _BaseOptions( + model_asset_path='/path/to/invalid/model.tflite' + ) + options = _FaceAlignerOptions(base_options=base_options) + _FaceAligner.create_from_options(options) + + def test_create_from_options_succeeds_with_valid_model_content(self): + # Creates with options containing model content successfully. + with open(self.model_path, 'rb') as f: + base_options = _BaseOptions(model_asset_buffer=f.read()) + options = _FaceAlignerOptions(base_options=base_options) + aligner = _FaceAligner.create_from_options(options) + self.assertIsInstance(aligner, _FaceAligner) + + @parameterized.parameters( + (ModelFileType.FILE_NAME, _LARGE_FACE_IMAGE), + (ModelFileType.FILE_CONTENT, _LARGE_FACE_IMAGE), + ) + def test_align(self, model_file_type, image_file_name): + # Load the test image. + self.test_image = _Image.create_from_file( + test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, image_file_name) + ) + ) + # Creates aligner. 
+ if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=self.model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(self.model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + options = _FaceAlignerOptions(base_options=base_options) + aligner = _FaceAligner.create_from_options(options) + + # Performs face alignment on the input. + aligned_image = aligner.align(self.test_image) + self.assertIsInstance(aligned_image, _Image) + # Closes the aligner explicitly when the aligner is not used in + # a context. + aligner.close() + + @parameterized.parameters( + (ModelFileType.FILE_NAME, _LARGE_FACE_IMAGE), + (ModelFileType.FILE_CONTENT, _LARGE_FACE_IMAGE), + ) + def test_align_in_context(self, model_file_type, image_file_name): + # Load the test image. + self.test_image = _Image.create_from_file( + test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, image_file_name) + ) + ) + # Creates aligner. + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=self.model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(self.model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + options = _FaceAlignerOptions(base_options=base_options) + with _FaceAligner.create_from_options(options) as aligner: + # Performs face alignment on the input. + aligned_image = aligner.align(self.test_image) + self.assertIsInstance(aligned_image, _Image) + self.assertEqual(aligned_image.width, _MODEL_IMAGE_SIZE) + self.assertEqual(aligned_image.height, _MODEL_IMAGE_SIZE) + + def test_align_succeeds_with_region_of_interest(self): + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _FaceAlignerOptions(base_options=base_options) + with _FaceAligner.create_from_options(options) as aligner: + # Load the test image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _LARGE_FACE_IMAGE) + ) + ) + # Region-of-interest around the face. + roi = _Rect(left=0.32, top=0.02, right=0.67, bottom=0.32) + image_processing_options = _ImageProcessingOptions(roi) + # Performs face alignment on the input. + aligned_image = aligner.align(test_image, image_processing_options) + self.assertIsInstance(aligned_image, _Image) + self.assertEqual(aligned_image.width, _MODEL_IMAGE_SIZE) + self.assertEqual(aligned_image.height, _MODEL_IMAGE_SIZE) + + def test_align_succeeds_with_no_face_detected(self): + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _FaceAlignerOptions(base_options=base_options) + with _FaceAligner.create_from_options(options) as aligner: + # Load the test image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _LARGE_FACE_IMAGE) + ) + ) + # Region-of-interest that doesn't contain a human face. + roi = _Rect(left=0.1, top=0.1, right=0.2, bottom=0.2) + image_processing_options = _ImageProcessingOptions(roi) + # Performs face alignment on the input.
+ aligned_image = aligner.align(test_image, image_processing_options) + self.assertIsNone(aligned_image) + + +if __name__ == '__main__': + absltest.main() diff --git a/mediapipe/tasks/python/text/BUILD b/mediapipe/tasks/python/text/BUILD index cdc41672b..fa5e70b63 100644 --- a/mediapipe/tasks/python/text/BUILD +++ b/mediapipe/tasks/python/text/BUILD @@ -57,3 +57,22 @@ py_library( "//mediapipe/tasks/python/text/core:base_text_task_api", ], ) + +py_library( + name = "language_detector", + srcs = [ + "language_detector.py", + ], + deps = [ + "//mediapipe/python:packet_creator", + "//mediapipe/python:packet_getter", + "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2", + "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_py_pb2", + "//mediapipe/tasks/python/components/containers:classification_result", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/core:optional_dependencies", + "//mediapipe/tasks/python/core:task_info", + "//mediapipe/tasks/python/text/core:base_text_task_api", + ], +) diff --git a/mediapipe/tasks/python/text/__init__.py b/mediapipe/tasks/python/text/__init__.py index 5aa221c33..66c62cafc 100644 --- a/mediapipe/tasks/python/text/__init__.py +++ b/mediapipe/tasks/python/text/__init__.py @@ -14,9 +14,13 @@ """MediaPipe Tasks Text API.""" +import mediapipe.tasks.python.text.language_detector import mediapipe.tasks.python.text.text_classifier import mediapipe.tasks.python.text.text_embedder +LanguageDetector = language_detector.LanguageDetector +LanguageDetectorOptions = language_detector.LanguageDetectorOptions +LanguageDetectorResult = language_detector.LanguageDetectorResult TextClassifier = text_classifier.TextClassifier TextClassifierOptions = text_classifier.TextClassifierOptions TextClassifierResult = text_classifier.TextClassifierResult @@ -26,5 +30,6 @@ TextEmbedderResult = text_embedder.TextEmbedderResult # Remove unnecessary modules to avoid duplication in API docs. del mediapipe +del language_detector del text_classifier del text_embedder diff --git a/mediapipe/tasks/python/text/language_detector.py b/mediapipe/tasks/python/text/language_detector.py new file mode 100644 index 000000000..6b27a458a --- /dev/null +++ b/mediapipe/tasks/python/text/language_detector.py @@ -0,0 +1,220 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+"""MediaPipe language detector task.""" + +import dataclasses +from typing import List, Optional + +from mediapipe.python import packet_creator +from mediapipe.python import packet_getter +from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 +from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2 +from mediapipe.tasks.cc.text.text_classifier.proto import text_classifier_graph_options_pb2 +from mediapipe.tasks.python.components.containers import classification_result as classification_result_module +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.core import task_info as task_info_module +from mediapipe.tasks.python.core.optional_dependencies import doc_controls +from mediapipe.tasks.python.text.core import base_text_task_api + +_ClassificationResult = classification_result_module.ClassificationResult +_BaseOptions = base_options_module.BaseOptions +_TextClassifierGraphOptionsProto = ( + text_classifier_graph_options_pb2.TextClassifierGraphOptions +) +_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions +_TaskInfo = task_info_module.TaskInfo + +_CLASSIFICATIONS_STREAM_NAME = 'classifications_out' +_CLASSIFICATIONS_TAG = 'CLASSIFICATIONS' +_TEXT_IN_STREAM_NAME = 'text_in' +_TEXT_TAG = 'TEXT' +_TASK_GRAPH_NAME = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph' + + +@dataclasses.dataclass +class LanguageDetectorResult: + + @dataclasses.dataclass + class Detection: + """A language code and its probability.""" + + # An i18n language / locale code, e.g. "en" for English, "uz" for Uzbek, + # "ja"-Latn for Japanese (romaji). + language_code: str + probability: float + + detections: List[Detection] + + +def _extract_language_detector_result( + classification_result: classification_result_module.ClassificationResult, +) -> LanguageDetectorResult: + """Extracts a LanguageDetectorResult from a ClassificationResult.""" + if len(classification_result.classifications) != 1: + raise ValueError( + 'The LanguageDetector TextClassifierGraph should have exactly one ' + 'classification head.' + ) + languages_and_scores = classification_result.classifications[0] + language_detector_result = LanguageDetectorResult([]) + for category in languages_and_scores.categories: + if category.category_name is None: + raise ValueError( + 'LanguageDetector ClassificationResult has a missing language code.' + ) + prediction = LanguageDetectorResult.Detection( + category.category_name, category.score + ) + language_detector_result.detections.append(prediction) + return language_detector_result + + +@dataclasses.dataclass +class LanguageDetectorOptions: + """Options for the language detector task. + + Attributes: + base_options: Base options for the language detector task. + display_names_locale: The locale to use for display names specified through + the TFLite Model Metadata. + max_results: The maximum number of top-scored classification results to + return. + score_threshold: Overrides the ones provided in the model metadata. Results + below this value are rejected. + category_allowlist: Allowlist of category names. If non-empty, + classification results whose category name is not in this set will be + filtered out. Duplicate or unknown category names are ignored. Mutually + exclusive with `category_denylist`. + category_denylist: Denylist of category names. If non-empty, classification + results whose category name is in this set will be filtered out. 
Duplicate + or unknown category names are ignored. Mutually exclusive with + `category_allowlist`. + """ + + base_options: _BaseOptions + display_names_locale: Optional[str] = None + max_results: Optional[int] = None + score_threshold: Optional[float] = None + category_allowlist: Optional[List[str]] = None + category_denylist: Optional[List[str]] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _TextClassifierGraphOptionsProto: + """Generates a TextClassifierGraphOptions protobuf object.""" + base_options_proto = self.base_options.to_pb2() + classifier_options_proto = _ClassifierOptionsProto( + score_threshold=self.score_threshold, + category_allowlist=self.category_allowlist, + category_denylist=self.category_denylist, + display_names_locale=self.display_names_locale, + max_results=self.max_results, + ) + + return _TextClassifierGraphOptionsProto( + base_options=base_options_proto, + classifier_options=classifier_options_proto, + ) + + +class LanguageDetector(base_text_task_api.BaseTextTaskApi): + """Class that predicts the language of an input text. + + This API expects a TFLite model with TFLite Model Metadata that contains the + mandatory (described below) input tensors, output tensor, and the language + codes in an AssociatedFile. + + Input tensors: + (kTfLiteString) + - 1 input tensor that is scalar or has shape [1] containing the input + string. + Output tensor: + (kTfLiteFloat32) + - 1 output tensor of shape `[1 x N]` where `N` is the number of languages. + """ + + @classmethod + def create_from_model_path(cls, model_path: str) -> 'LanguageDetector': + """Creates a `LanguageDetector` object from a TensorFlow Lite model and the default `LanguageDetectorOptions`. + + Args: + model_path: Path to the model. + + Returns: + `LanguageDetector` object that's created from the model file and the + default `LanguageDetectorOptions`. + + Raises: + ValueError: If failed to create `LanguageDetector` object from the + provided + file such as invalid file path. + RuntimeError: If other types of error occurred. + """ + base_options = _BaseOptions(model_asset_path=model_path) + options = LanguageDetectorOptions(base_options=base_options) + return cls.create_from_options(options) + + @classmethod + def create_from_options( + cls, options: LanguageDetectorOptions + ) -> 'LanguageDetector': + """Creates the `LanguageDetector` object from language detector options. + + Args: + options: Options for the language detector task. + + Returns: + `LanguageDetector` object that's created from `options`. + + Raises: + ValueError: If failed to create `LanguageDetector` object from + `LanguageDetectorOptions` such as missing the model. + RuntimeError: If other types of error occurred. + """ + task_info = _TaskInfo( + task_graph=_TASK_GRAPH_NAME, + input_streams=[':'.join([_TEXT_TAG, _TEXT_IN_STREAM_NAME])], + output_streams=[ + ':'.join([_CLASSIFICATIONS_TAG, _CLASSIFICATIONS_STREAM_NAME]) + ], + task_options=options, + ) + return cls(task_info.generate_graph_config()) + + def detect(self, text: str) -> LanguageDetectorResult: + """Predicts the language of the input `text`. + + Args: + text: The input text. + + Returns: + A `LanguageDetectorResult` object that contains a list of languages and + scores. + + Raises: + ValueError: If any of the input arguments is invalid. + RuntimeError: If language detection failed to run.
+ """ + output_packets = self._runner.process( + {_TEXT_IN_STREAM_NAME: packet_creator.create_string(text)} + ) + + classification_result_proto = classifications_pb2.ClassificationResult() + classification_result_proto.CopyFrom( + packet_getter.get_proto(output_packets[_CLASSIFICATIONS_STREAM_NAME]) + ) + + classification_result = _ClassificationResult.create_from_pb2( + classification_result_proto + ) + return _extract_language_detector_result(classification_result) diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index d0c97434f..dcd28dcf5 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -264,3 +264,22 @@ py_library( "//mediapipe/tasks/python/vision/core:vision_task_running_mode", ], ) + +py_library( + name = "face_aligner", + srcs = [ + "face_aligner.py", + ], + deps = [ + "//mediapipe/python:_framework_bindings", + "//mediapipe/python:packet_creator", + "//mediapipe/python:packet_getter", + "//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_py_pb2", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/core:optional_dependencies", + "//mediapipe/tasks/python/core:task_info", + "//mediapipe/tasks/python/vision/core:base_vision_task_api", + "//mediapipe/tasks/python/vision/core:image_processing_options", + "//mediapipe/tasks/python/vision/core:vision_task_running_mode", + ], +) diff --git a/mediapipe/tasks/python/vision/__init__.py b/mediapipe/tasks/python/vision/__init__.py index 75a8bd323..c88dbb9ad 100644 --- a/mediapipe/tasks/python/vision/__init__.py +++ b/mediapipe/tasks/python/vision/__init__.py @@ -15,6 +15,7 @@ """MediaPipe Tasks Vision API.""" import mediapipe.tasks.python.vision.core +import mediapipe.tasks.python.vision.face_aligner import mediapipe.tasks.python.vision.face_detector import mediapipe.tasks.python.vision.face_landmarker import mediapipe.tasks.python.vision.face_stylizer @@ -25,7 +26,10 @@ import mediapipe.tasks.python.vision.image_embedder import mediapipe.tasks.python.vision.image_segmenter import mediapipe.tasks.python.vision.interactive_segmenter import mediapipe.tasks.python.vision.object_detector +import mediapipe.tasks.python.vision.pose_landmarker +FaceAligner = face_aligner.FaceAligner +FaceAlignerOptions = face_aligner.FaceAlignerOptions FaceDetector = face_detector.FaceDetector FaceDetectorOptions = face_detector.FaceDetectorOptions FaceDetectorResult = face_detector.FaceDetectorResult @@ -41,6 +45,7 @@ GestureRecognizerResult = gesture_recognizer.GestureRecognizerResult HandLandmarker = hand_landmarker.HandLandmarker HandLandmarkerOptions = hand_landmarker.HandLandmarkerOptions HandLandmarkerResult = hand_landmarker.HandLandmarkerResult +HandLandmarksConnections = hand_landmarker.HandLandmarksConnections ImageClassifier = image_classifier.ImageClassifier ImageClassifierOptions = image_classifier.ImageClassifierOptions ImageClassifierResult = image_classifier.ImageClassifierResult @@ -54,10 +59,16 @@ InteractiveSegmenterOptions = interactive_segmenter.InteractiveSegmenterOptions InteractiveSegmenterRegionOfInterest = interactive_segmenter.RegionOfInterest ObjectDetector = object_detector.ObjectDetector ObjectDetectorOptions = object_detector.ObjectDetectorOptions +ObjectDetectorResult = object_detector.ObjectDetectorResult +PoseLandmarker = pose_landmarker.PoseLandmarker +PoseLandmarkerOptions = pose_landmarker.PoseLandmarkerOptions +PoseLandmarkerResult = pose_landmarker.PoseLandmarkerResult +PoseLandmarksConnections = 
pose_landmarker.PoseLandmarksConnections RunningMode = core.vision_task_running_mode.VisionTaskRunningMode # Remove unnecessary modules to avoid duplication in API docs. del core +del face_aligner del face_detector del face_landmarker del face_stylizer @@ -68,4 +79,5 @@ del image_embedder del image_segmenter del interactive_segmenter del object_detector +del pose_landmarker del mediapipe diff --git a/mediapipe/tasks/python/vision/face_aligner.py b/mediapipe/tasks/python/vision/face_aligner.py new file mode 100644 index 000000000..53bf71f35 --- /dev/null +++ b/mediapipe/tasks/python/vision/face_aligner.py @@ -0,0 +1,158 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MediaPipe face aligner task.""" + +import dataclasses +from typing import Optional + +from mediapipe.python import packet_creator +from mediapipe.python import packet_getter +from mediapipe.python._framework_bindings import image as image_module +from mediapipe.tasks.cc.vision.face_stylizer.proto import face_stylizer_graph_options_pb2 +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.core import task_info as task_info_module +from mediapipe.tasks.python.core.optional_dependencies import doc_controls +from mediapipe.tasks.python.vision.core import base_vision_task_api +from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module +from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module + +_BaseOptions = base_options_module.BaseOptions +_FaceStylizerGraphOptionsProto = ( + face_stylizer_graph_options_pb2.FaceStylizerGraphOptions +) +_RunningMode = running_mode_module.VisionTaskRunningMode +_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions +_TaskInfo = task_info_module.TaskInfo + +_FACE_ALIGNMENT_IMAGE_NAME = 'face_alignment' +_FACE_ALIGNMENT_IMAGE_TAG = 'FACE_ALIGNMENT' +_NORM_RECT_STREAM_NAME = 'norm_rect_in' +_NORM_RECT_TAG = 'NORM_RECT' +_IMAGE_IN_STREAM_NAME = 'image_in' +_IMAGE_OUT_STREAM_NAME = 'image_out' +_IMAGE_TAG = 'IMAGE' +_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.face_stylizer.FaceStylizerGraph' + + +@dataclasses.dataclass +class FaceAlignerOptions: + """Options for the face aligner task. + + Attributes: + base_options: Base options for the face aligner task. 
+ """ + + base_options: _BaseOptions + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _FaceStylizerGraphOptionsProto: + """Generates a FaceStylizerOptions protobuf object.""" + base_options_proto = self.base_options.to_pb2() + base_options_proto.use_stream_mode = False + return _FaceStylizerGraphOptionsProto(base_options=base_options_proto) + + +class FaceAligner(base_vision_task_api.BaseVisionTaskApi): + """Class that performs face alignment on images.""" + + @classmethod + def create_from_model_path(cls, model_path: str) -> 'FaceAligner': + """Creates a `FaceAligner` object from a face landmarker task bundle and the default `FaceAlignerOptions`. + + Note that the created `FaceAligner` instance is in image mode, for + aligning one face on a single image input. + + Args: + model_path: Path to the face landmarker task bundle. + + Returns: + `FaceAligner` object that's created from the model file and the default + `FaceAlignerOptions`. + + Raises: + ValueError: If failed to create `FaceAligner` object from the provided + file such as invalid file path. + RuntimeError: If other types of error occurred. + """ + base_options = _BaseOptions(model_asset_path=model_path) + options = FaceAlignerOptions(base_options=base_options) + return cls.create_from_options(options) + + @classmethod + def create_from_options(cls, options: FaceAlignerOptions) -> 'FaceAligner': + """Creates the `FaceAligner` object from face aligner options. + + Args: + options: Options for the face aligner task. + + Returns: + `FaceAligner` object that's created from `options`. + + Raises: + ValueError: If failed to create `FaceAligner` object from + `FaceAlignerOptions` such as missing the model. + RuntimeError: If other types of error occurred. + """ + task_info = _TaskInfo( + task_graph=_TASK_GRAPH_NAME, + input_streams=[ + ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), + ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), + ], + output_streams=[ + ':'.join([_FACE_ALIGNMENT_IMAGE_TAG, _FACE_ALIGNMENT_IMAGE_NAME]), + ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), + ], + task_options=options, + ) + return cls( + task_info.generate_graph_config(enable_flow_limiting=False), + _RunningMode.IMAGE, + None, + ) + + def align( + self, + image: image_module.Image, + image_processing_options: Optional[_ImageProcessingOptions] = None, + ) -> image_module.Image: + """Performs face alignment on the provided MediaPipe Image. + + Only use this method when the FaceAligner is created with the image + running mode. + + Args: + image: MediaPipe Image. + image_processing_options: Options for image processing. + + Returns: + The aligned face image. The aligned output image size is the same as the + model output size. None if no face is detected on the input image. + + Raises: + ValueError: If any of the input arguments is invalid. + RuntimeError: If face alignment failed to run. 
+ """ + normalized_rect = self.convert_to_normalized_rect( + image_processing_options, image + ) + output_packets = self._process_image_data({ + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ), + }) + if output_packets[_FACE_ALIGNMENT_IMAGE_NAME].is_empty(): + return None + return packet_getter.get_image(output_packets[_FACE_ALIGNMENT_IMAGE_NAME]) diff --git a/mediapipe/tasks/python/vision/face_landmarker.py b/mediapipe/tasks/python/vision/face_landmarker.py index 870e7e43e..44ddba87e 100644 --- a/mediapipe/tasks/python/vision/face_landmarker.py +++ b/mediapipe/tasks/python/vision/face_landmarker.py @@ -2939,7 +2939,7 @@ class FaceLandmarkerOptions: Attributes: base_options: Base options for the face landmarker task. running_mode: The running mode of the task. Default to the image mode. - HandLandmarker has three running modes: 1) The image mode for detecting + FaceLandmarker has three running modes: 1) The image mode for detecting face landmarks on single image inputs. 2) The video mode for detecting face landmarks on the decoded frames of a video. 3) The live stream mode for detecting face landmarks on the live stream of input data, such as diff --git a/mediapipe/tasks/python/vision/hand_landmarker.py b/mediapipe/tasks/python/vision/hand_landmarker.py index 1f2c629d2..e781c8882 100644 --- a/mediapipe/tasks/python/vision/hand_landmarker.py +++ b/mediapipe/tasks/python/vision/hand_landmarker.py @@ -82,6 +82,65 @@ class HandLandmark(enum.IntEnum): PINKY_TIP = 20 +class HandLandmarksConnections: + """The connections between hand landmarks.""" + + @dataclasses.dataclass + class Connection: + """The connection class for hand landmarks.""" + + start: int + end: int + + HAND_PALM_CONNECTIONS: List[Connection] = [ + Connection(0, 1), + Connection(1, 5), + Connection(9, 13), + Connection(13, 17), + Connection(5, 9), + Connection(0, 17), + ] + + HAND_THUMB_CONNECTIONS: List[Connection] = [ + Connection(1, 2), + Connection(2, 3), + Connection(3, 4), + ] + + HAND_INDEX_FINGER_CONNECTIONS: List[Connection] = [ + Connection(5, 6), + Connection(6, 7), + Connection(7, 8), + ] + + HAND_MIDDLE_FINGER_CONNECTIONS: List[Connection] = [ + Connection(9, 10), + Connection(10, 11), + Connection(11, 12), + ] + + HAND_RING_FINGER_CONNECTIONS: List[Connection] = [ + Connection(13, 14), + Connection(14, 15), + Connection(15, 16), + ] + + HAND_PINKY_FINGER_CONNECTIONS: List[Connection] = [ + Connection(17, 18), + Connection(18, 19), + Connection(19, 20), + ] + + HAND_CONNECTIONS: List[Connection] = ( + HAND_PALM_CONNECTIONS + + HAND_THUMB_CONNECTIONS + + HAND_INDEX_FINGER_CONNECTIONS + + HAND_MIDDLE_FINGER_CONNECTIONS + + HAND_RING_FINGER_CONNECTIONS + + HAND_PINKY_FINGER_CONNECTIONS + ) + + @dataclasses.dataclass class HandLandmarkerResult: """The hand landmarks result from HandLandmarker, where each vector element represents a single hand detected in the image. 
diff --git a/mediapipe/tasks/python/vision/interactive_segmenter.py b/mediapipe/tasks/python/vision/interactive_segmenter.py index 9ee3cb467..0cca4b0f6 100644 --- a/mediapipe/tasks/python/vision/interactive_segmenter.py +++ b/mediapipe/tasks/python/vision/interactive_segmenter.py @@ -88,7 +88,7 @@ class InteractiveSegmenterOptions: @doc_controls.do_not_generate_docs def to_pb2(self) -> _ImageSegmenterGraphOptionsProto: - """Generates an InteractiveSegmenterOptions protobuf object.""" + """Generates an ImageSegmenterGraphOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False segmenter_options_proto = _SegmenterOptionsProto() diff --git a/mediapipe/tasks/python/vision/pose_landmarker.py b/mediapipe/tasks/python/vision/pose_landmarker.py index b91eb0326..3ff7edb0a 100644 --- a/mediapipe/tasks/python/vision/pose_landmarker.py +++ b/mediapipe/tasks/python/vision/pose_landmarker.py @@ -132,6 +132,55 @@ def _build_landmarker_result( return pose_landmarker_result +class PoseLandmarksConnections: + """The connections between pose landmarks.""" + + @dataclasses.dataclass + class Connection: + """The connection class for pose landmarks.""" + + start: int + end: int + + POSE_LANDMARKS: List[Connection] = [ + Connection(0, 1), + Connection(1, 2), + Connection(2, 3), + Connection(3, 7), + Connection(0, 4), + Connection(4, 5), + Connection(5, 6), + Connection(6, 8), + Connection(9, 10), + Connection(11, 12), + Connection(11, 13), + Connection(13, 15), + Connection(15, 17), + Connection(15, 19), + Connection(15, 21), + Connection(17, 19), + Connection(12, 14), + Connection(14, 16), + Connection(16, 18), + Connection(16, 20), + Connection(16, 22), + Connection(18, 20), + Connection(11, 23), + Connection(12, 24), + Connection(23, 24), + Connection(23, 25), + Connection(24, 26), + Connection(25, 27), + Connection(26, 28), + Connection(27, 29), + Connection(28, 30), + Connection(29, 31), + Connection(30, 32), + Connection(27, 31), + Connection(28, 32) + ] + + @dataclasses.dataclass class PoseLandmarkerOptions: """Options for the pose landmarker task. 
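Note: PoseLandmarksConnections.POSE_LANDMARKS above is likewise just a list of index pairs, here into the 33 pose landmarks. The sketch below is illustrative only and not part of this patch; it flattens the Connection dataclasses into plain (start, end) tuples and measures each connected segment for one pose, assuming the per-person lists of normalized landmarks exposed by PoseLandmarkerResult.pose_landmarks.

# Sketch (illustrative only): treat POSE_LANDMARKS as plain (start, end) tuples.
import math
from mediapipe.tasks.python.vision.pose_landmarker import PoseLandmarksConnections

# Plain tuples are handy for helpers that do not know the Connection dataclass.
POSE_CONNECTION_PAIRS = [
    (c.start, c.end) for c in PoseLandmarksConnections.POSE_LANDMARKS
]

def segment_lengths(pose_landmarks):
  """Returns the normalized length of every connected segment for one pose."""
  lengths = {}
  for start, end in POSE_CONNECTION_PAIRS:
    a, b = pose_landmarks[start], pose_landmarks[end]
    lengths[(start, end)] = math.hypot(a.x - b.x, a.y - b.y)
  return lengths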
diff --git a/mediapipe/util/annotation_renderer.cc b/mediapipe/util/annotation_renderer.cc index 5188da896..d8516f9bc 100644 --- a/mediapipe/util/annotation_renderer.cc +++ b/mediapipe/util/annotation_renderer.cc @@ -22,6 +22,7 @@ #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/vector.h" #include "mediapipe/util/color.pb.h" +#include "mediapipe/util/render_data.pb.h" namespace mediapipe { namespace { @@ -112,6 +113,8 @@ void AnnotationRenderer::RenderDataOnImage(const RenderData& render_data) { DrawGradientLine(annotation); } else if (annotation.data_case() == RenderAnnotation::kArrow) { DrawArrow(annotation); + } else if (annotation.data_case() == RenderAnnotation::kScribble) { + DrawScribble(annotation); } else { LOG(FATAL) << "Unknown annotation type: " << annotation.data_case(); } @@ -442,7 +445,11 @@ void AnnotationRenderer::DrawArrow(const RenderAnnotation& annotation) { } void AnnotationRenderer::DrawPoint(const RenderAnnotation& annotation) { - const auto& point = annotation.point(); + DrawPoint(annotation.point(), annotation); +} + +void AnnotationRenderer::DrawPoint(const RenderAnnotation::Point& point, + const RenderAnnotation& annotation) { int x = -1; int y = -1; if (point.normalized()) { @@ -460,6 +467,12 @@ void AnnotationRenderer::DrawPoint(const RenderAnnotation& annotation) { cv::circle(mat_image_, point_to_draw, thickness, color, -1); } +void AnnotationRenderer::DrawScribble(const RenderAnnotation& annotation) { + for (const RenderAnnotation::Point& point : annotation.scribble().point()) { + DrawPoint(point, annotation); + } +} + void AnnotationRenderer::DrawLine(const RenderAnnotation& annotation) { int x_start = -1; int y_start = -1; diff --git a/mediapipe/util/annotation_renderer.h b/mediapipe/util/annotation_renderer.h index 380bc3614..ae0cf976e 100644 --- a/mediapipe/util/annotation_renderer.h +++ b/mediapipe/util/annotation_renderer.h @@ -96,6 +96,11 @@ class AnnotationRenderer { // Draws a point on the image as described in the annotation. void DrawPoint(const RenderAnnotation& annotation); + void DrawPoint(const RenderAnnotation::Point& point, + const RenderAnnotation& annotation); + + // Draws scribbles on the image as described in the annotation. + void DrawScribble(const RenderAnnotation& annotation); // Draws a line segment on the image as described in the annotation. void DrawLine(const RenderAnnotation& annotation); diff --git a/mediapipe/util/render_data.proto b/mediapipe/util/render_data.proto index fee02fff3..897d5fa37 100644 --- a/mediapipe/util/render_data.proto +++ b/mediapipe/util/render_data.proto @@ -131,6 +131,10 @@ message RenderAnnotation { optional Color color2 = 7; } + message Scribble { + repeated Point point = 1; + } + message Arrow { // The arrow head will be drawn at (x_end, y_end). optional double x_start = 1; @@ -192,6 +196,7 @@ message RenderAnnotation { RoundedRectangle rounded_rectangle = 9; FilledRoundedRectangle filled_rounded_rectangle = 10; GradientLine gradient_line = 14; + Scribble scribble = 15; } // Thickness for drawing the annotation.