diff --git a/mediapipe/tasks/cc/components/proto/segmenter_options.proto b/mediapipe/tasks/cc/components/proto/segmenter_options.proto index a2f37d3a0..ca9986707 100644 --- a/mediapipe/tasks/cc/components/proto/segmenter_options.proto +++ b/mediapipe/tasks/cc/components/proto/segmenter_options.proto @@ -17,6 +17,9 @@ syntax = "proto2"; package mediapipe.tasks.components.proto; +option java_package = "com.google.mediapipe.tasks.components.proto"; +option java_outer_classname = "SegmenterOptionsProto"; + // Shared options used by image segmentation tasks. message SegmenterOptions { // Optional output mask type. diff --git a/mediapipe/tasks/cc/vision/image_segmenter/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD index 81cd43e34..4c43a07f5 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD @@ -32,7 +32,7 @@ cc_library( "//mediapipe/tasks/cc/vision/core:image_processing_options", "//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", - "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", @@ -63,7 +63,7 @@ cc_library( "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", "//mediapipe/tasks/cc/metadata:metadata_extractor", - "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto", "//mediapipe/tasks/metadata:metadata_schema_cc", "//mediapipe/util:label_map_cc_proto", "//mediapipe/util:label_map_util", diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc index 209ee0df3..6dce1b4ea 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc @@ -23,10 +23,12 @@ limitations under the License. #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" namespace mediapipe { namespace tasks { namespace vision { +namespace image_segmenter { namespace { constexpr char kSegmentationStreamName[] = "segmented_mask_out"; @@ -37,23 +39,24 @@ constexpr char kImageTag[] = "IMAGE"; constexpr char kNormRectStreamName[] = "norm_rect_in"; constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kSubgraphTypeName[] = - "mediapipe.tasks.vision.ImageSegmenterGraph"; + "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; using ::mediapipe::CalculatorGraphConfig; using ::mediapipe::Image; using ::mediapipe::tasks::components::proto::SegmenterOptions; -using ImageSegmenterOptionsProto = - image_segmenter::proto::ImageSegmenterOptions; +using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision:: + image_segmenter::proto::ImageSegmenterGraphOptions; // Creates a MediaPipe graph config that only contains a single subgraph node of -// "mediapipe.tasks.vision.ImageSegmenterGraph". +// "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph". CalculatorGraphConfig CreateGraphConfig( - std::unique_ptr options, + std::unique_ptr options, bool enable_flow_limiting) { api2::builder::Graph graph; auto& task_subgraph = graph.AddNode(kSubgraphTypeName); - task_subgraph.GetOptions().Swap(options.get()); + task_subgraph.GetOptions().Swap( + options.get()); graph.In(kImageTag).SetName(kImageInStreamName); graph.In(kNormRectTag).SetName(kNormRectStreamName); task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >> @@ -72,9 +75,9 @@ CalculatorGraphConfig CreateGraphConfig( // Converts the user-facing ImageSegmenterOptions struct to the internal // ImageSegmenterOptions proto. -std::unique_ptr ConvertImageSegmenterOptionsToProto( - ImageSegmenterOptions* options) { - auto options_proto = std::make_unique(); +std::unique_ptr +ConvertImageSegmenterOptionsToProto(ImageSegmenterOptions* options) { + auto options_proto = std::make_unique(); auto base_options_proto = std::make_unique( tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); options_proto->mutable_base_options()->Swap(base_options_proto.get()); @@ -137,7 +140,7 @@ absl::StatusOr> ImageSegmenter::Create( }; } return core::VisionTaskApiFactory::Create( + ImageSegmenterGraphOptionsProto>( CreateGraphConfig( std::move(options_proto), options->running_mode == core::RunningMode::LIVE_STREAM), @@ -211,6 +214,7 @@ absl::Status ImageSegmenter::SegmentAsync( .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}); } +} // namespace image_segmenter } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h index 54269ec0e..43bf5b7e6 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h @@ -26,12 +26,12 @@ limitations under the License. #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" -#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h" #include "tensorflow/lite/kernels/register.h" namespace mediapipe { namespace tasks { namespace vision { +namespace image_segmenter { // The options for configuring a mediapipe image segmenter task. struct ImageSegmenterOptions { @@ -191,6 +191,7 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { absl::Status Close() { return runner_->Close(); } }; +} // namespace image_segmenter } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index 31fed6d8d..44742e043 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -35,7 +35,7 @@ limitations under the License. #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" -#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" #include "mediapipe/util/label_map.pb.h" #include "mediapipe/util/label_map_util.h" @@ -44,6 +44,7 @@ limitations under the License. namespace mediapipe { namespace tasks { namespace vision { +namespace image_segmenter { namespace { @@ -55,7 +56,8 @@ using ::mediapipe::api2::builder::MultiSource; using ::mediapipe::api2::builder::Source; using ::mediapipe::tasks::components::proto::SegmenterOptions; using ::mediapipe::tasks::metadata::ModelMetadataExtractor; -using ::mediapipe::tasks::vision::image_segmenter::proto::ImageSegmenterOptions; +using ::mediapipe::tasks::vision::image_segmenter::proto:: + ImageSegmenterGraphOptions; using ::tflite::Tensor; using ::tflite::TensorMetadata; using LabelItems = mediapipe::proto_ns::Map; @@ -77,7 +79,7 @@ struct ImageSegmenterOutputs { } // namespace -absl::Status SanityCheckOptions(const ImageSegmenterOptions& options) { +absl::Status SanityCheckOptions(const ImageSegmenterGraphOptions& options) { if (options.segmenter_options().output_type() == SegmenterOptions::UNSPECIFIED) { return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument, @@ -112,7 +114,7 @@ absl::StatusOr GetLabelItemsIfAny( } absl::Status ConfigureTensorsToSegmentationCalculator( - const ImageSegmenterOptions& segmenter_option, + const ImageSegmenterGraphOptions& segmenter_option, const core::ModelResources& model_resources, TensorsToSegmentationCalculatorOptions* options) { *options->mutable_segmenter_options() = segmenter_option.segmenter_options(); @@ -181,7 +183,7 @@ absl::StatusOr GetOutputTensor( // input_stream: "IMAGE:image" // output_stream: "SEGMENTATION:segmented_masks" // options { -// [mediapipe.tasks.vision.image_segmenter.proto.ImageSegmenterOptions.ext] +// [mediapipe.tasks.vision.image_segmenter.proto.ImageSegmenterGraphOptions.ext] // { // base_options { // model_asset { @@ -200,12 +202,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { absl::StatusOr GetConfig( mediapipe::SubgraphContext* sc) override { ASSIGN_OR_RETURN(const auto* model_resources, - CreateModelResources(sc)); + CreateModelResources(sc)); Graph graph; ASSIGN_OR_RETURN( auto output_streams, BuildSegmentationTask( - sc->Options(), *model_resources, + sc->Options(), *model_resources, graph[Input(kImageTag)], graph[Input::Optional(kNormRectTag)], graph)); @@ -228,13 +230,13 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { // builder::Graph instance. The segmentation pipeline takes images // (mediapipe::Image) as the input and returns segmented image mask as output. // - // task_options: the mediapipe tasks ImageSegmenterOptions proto. + // task_options: the mediapipe tasks ImageSegmenterGraphOptions proto. // model_resources: the ModelSources object initialized from a segmentation // model file with model metadata. // image_in: (mediapipe::Image) stream to run segmentation on. // graph: the mediapipe builder::Graph instance to be updated. absl::StatusOr BuildSegmentationTask( - const ImageSegmenterOptions& task_options, + const ImageSegmenterGraphOptions& task_options, const core::ModelResources& model_resources, Source image_in, Source norm_rect_in, Graph& graph) { MP_RETURN_IF_ERROR(SanityCheckOptions(task_options)); @@ -293,8 +295,10 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { } }; -REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::vision::ImageSegmenterGraph); +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::vision::image_segmenter::ImageSegmenterGraph); +} // namespace image_segmenter } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc index 07235563b..752a116dd 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc @@ -33,7 +33,7 @@ limitations under the License. #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" -#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" #include "tensorflow/lite/kernels/builtin_op_kernels.h" @@ -42,6 +42,7 @@ limitations under the License. namespace mediapipe { namespace tasks { namespace vision { +namespace image_segmenter { namespace { using ::mediapipe::Image; @@ -547,6 +548,7 @@ TEST_F(LiveStreamModeTest, Succeeds) { // TODO: Add test for hair segmentation model. } // namespace +} // namespace image_segmenter } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD index d768c2bb1..3b14060f1 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD @@ -19,8 +19,8 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) mediapipe_proto_library( - name = "image_segmenter_options_proto", - srcs = ["image_segmenter_options.proto"], + name = "image_segmenter_graph_options_proto", + srcs = ["image_segmenter_graph_options.proto"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", diff --git a/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.proto b/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto similarity index 85% rename from mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.proto rename to mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto index 6e24a6665..166e2e8e0 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.proto +++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.proto @@ -21,9 +21,12 @@ import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/components/proto/segmenter_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; -message ImageSegmenterOptions { +option java_package = "com.google.mediapipe.tasks.vision.imagesegmenter.proto"; +option java_outer_classname = "ImageSegmenterGraphOptionsProto"; + +message ImageSegmenterGraphOptions { extend mediapipe.CalculatorOptions { - optional ImageSegmenterOptions ext = 458105758; + optional ImageSegmenterGraphOptions ext = 458105758; } // Base options for configuring MediaPipe Tasks, such as specifying the TfLite // model file with metadata, accelerator options, etc. diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index ff40768f7..16308e71f 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -22,6 +22,7 @@ _CORE_TASKS_JAVA_PROTO_LITE_TARGETS = [ "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_java_proto_lite", + "//mediapipe/tasks/cc/components/proto:segmenter_options_java_proto_lite", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite", @@ -36,6 +37,7 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [ "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 9a46a31a8..527c6d883 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -70,7 +70,7 @@ py_library( "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", "//mediapipe/tasks/cc/components/proto:segmenter_options_py_pb2", - "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_py_pb2", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_py_pb2", "//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:task_info", diff --git a/mediapipe/tasks/python/vision/image_segmenter.py b/mediapipe/tasks/python/vision/image_segmenter.py index c1b50a5ae..1740d41ef 100644 --- a/mediapipe/tasks/python/vision/image_segmenter.py +++ b/mediapipe/tasks/python/vision/image_segmenter.py @@ -22,7 +22,7 @@ from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet from mediapipe.tasks.cc.components.proto import segmenter_options_pb2 -from mediapipe.tasks.cc.vision.image_segmenter.proto import image_segmenter_options_pb2 +from mediapipe.tasks.cc.vision.image_segmenter.proto import image_segmenter_graph_options_pb2 from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls @@ -31,7 +31,7 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode _BaseOptions = base_options_module.BaseOptions _SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions -_ImageSegmenterOptionsProto = image_segmenter_options_pb2.ImageSegmenterOptions +_ImageSegmenterGraphOptionsProto = image_segmenter_graph_options_pb2.ImageSegmenterGraphOptions _RunningMode = vision_task_running_mode.VisionTaskRunningMode _TaskInfo = task_info_module.TaskInfo @@ -40,7 +40,7 @@ _SEGMENTATION_TAG = 'GROUPED_SEGMENTATION' _IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_TAG = 'IMAGE' -_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.ImageSegmenterGraph' +_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph' _MICRO_SECONDS_PER_MILLISECOND = 1000 @@ -81,13 +81,13 @@ class ImageSegmenterOptions: [List[image_module.Image], image_module.Image, int], None]] = None @doc_controls.do_not_generate_docs - def to_pb2(self) -> _ImageSegmenterOptionsProto: + def to_pb2(self) -> _ImageSegmenterGraphOptionsProto: """Generates an ImageSegmenterOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True segmenter_options_proto = _SegmenterOptionsProto( output_type=self.output_type.value, activation=self.activation.value) - return _ImageSegmenterOptionsProto( + return _ImageSegmenterGraphOptionsProto( base_options=base_options_proto, segmenter_options=segmenter_options_proto)