Add image_segmenter namespace and update image segmenter proto for java package.

PiperOrigin-RevId: 486792164
This commit is contained in:
MediaPipe Team 2022-11-07 16:17:25 -08:00 committed by Copybara-Service
parent 1a151d9ffb
commit 1049ef781d
11 changed files with 53 additions and 34 deletions

View File

@ -17,6 +17,9 @@ syntax = "proto2";
package mediapipe.tasks.components.proto;
option java_package = "com.google.mediapipe.tasks.components.proto";
option java_outer_classname = "SegmenterOptionsProto";
// Shared options used by image segmentation tasks.
message SegmenterOptions {
// Optional output mask type.

View File

@ -32,7 +32,7 @@ cc_library(
"//mediapipe/tasks/cc/vision/core:image_processing_options",
"//mediapipe/tasks/cc/vision/core:running_mode",
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status:statusor",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
@ -63,7 +63,7 @@ cc_library(
"//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
"//mediapipe/tasks/metadata:metadata_schema_cc",
"//mediapipe/util:label_map_cc_proto",
"//mediapipe/util:label_map_util",

View File

@ -23,10 +23,12 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace image_segmenter {
namespace {
constexpr char kSegmentationStreamName[] = "segmented_mask_out";
@ -37,23 +39,24 @@ constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectStreamName[] = "norm_rect_in";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kSubgraphTypeName[] =
"mediapipe.tasks.vision.ImageSegmenterGraph";
"mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
constexpr int kMicroSecondsPerMilliSecond = 1000;
using ::mediapipe::CalculatorGraphConfig;
using ::mediapipe::Image;
using ::mediapipe::tasks::components::proto::SegmenterOptions;
using ImageSegmenterOptionsProto =
image_segmenter::proto::ImageSegmenterOptions;
using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision::
image_segmenter::proto::ImageSegmenterGraphOptions;
// Creates a MediaPipe graph config that only contains a single subgraph node of
// "mediapipe.tasks.vision.ImageSegmenterGraph".
// "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph".
CalculatorGraphConfig CreateGraphConfig(
std::unique_ptr<ImageSegmenterOptionsProto> options,
std::unique_ptr<ImageSegmenterGraphOptionsProto> options,
bool enable_flow_limiting) {
api2::builder::Graph graph;
auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
task_subgraph.GetOptions<ImageSegmenterOptionsProto>().Swap(options.get());
task_subgraph.GetOptions<ImageSegmenterGraphOptionsProto>().Swap(
options.get());
graph.In(kImageTag).SetName(kImageInStreamName);
graph.In(kNormRectTag).SetName(kNormRectStreamName);
task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >>
@ -72,9 +75,9 @@ CalculatorGraphConfig CreateGraphConfig(
// Converts the user-facing ImageSegmenterOptions struct to the internal
// ImageSegmenterOptions proto.
std::unique_ptr<ImageSegmenterOptionsProto> ConvertImageSegmenterOptionsToProto(
ImageSegmenterOptions* options) {
auto options_proto = std::make_unique<ImageSegmenterOptionsProto>();
std::unique_ptr<ImageSegmenterGraphOptionsProto>
ConvertImageSegmenterOptionsToProto(ImageSegmenterOptions* options) {
auto options_proto = std::make_unique<ImageSegmenterGraphOptionsProto>();
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
options_proto->mutable_base_options()->Swap(base_options_proto.get());
@ -137,7 +140,7 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
};
}
return core::VisionTaskApiFactory::Create<ImageSegmenter,
ImageSegmenterOptionsProto>(
ImageSegmenterGraphOptionsProto>(
CreateGraphConfig(
std::move(options_proto),
options->running_mode == core::RunningMode::LIVE_STREAM),
@ -211,6 +214,7 @@ absl::Status ImageSegmenter::SegmentAsync(
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
}
} // namespace image_segmenter
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -26,12 +26,12 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
#include "tensorflow/lite/kernels/register.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace image_segmenter {
// The options for configuring a mediapipe image segmenter task.
struct ImageSegmenterOptions {
@ -191,6 +191,7 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
absl::Status Close() { return runner_->Close(); }
};
} // namespace image_segmenter
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -35,7 +35,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
#include "mediapipe/util/label_map.pb.h"
#include "mediapipe/util/label_map_util.h"
@ -44,6 +44,7 @@ limitations under the License.
namespace mediapipe {
namespace tasks {
namespace vision {
namespace image_segmenter {
namespace {
@ -55,7 +56,8 @@ using ::mediapipe::api2::builder::MultiSource;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::proto::SegmenterOptions;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
using ::mediapipe::tasks::vision::image_segmenter::proto::ImageSegmenterOptions;
using ::mediapipe::tasks::vision::image_segmenter::proto::
ImageSegmenterGraphOptions;
using ::tflite::Tensor;
using ::tflite::TensorMetadata;
using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
@ -77,7 +79,7 @@ struct ImageSegmenterOutputs {
} // namespace
absl::Status SanityCheckOptions(const ImageSegmenterOptions& options) {
absl::Status SanityCheckOptions(const ImageSegmenterGraphOptions& options) {
if (options.segmenter_options().output_type() ==
SegmenterOptions::UNSPECIFIED) {
return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
@ -112,7 +114,7 @@ absl::StatusOr<LabelItems> GetLabelItemsIfAny(
}
absl::Status ConfigureTensorsToSegmentationCalculator(
const ImageSegmenterOptions& segmenter_option,
const ImageSegmenterGraphOptions& segmenter_option,
const core::ModelResources& model_resources,
TensorsToSegmentationCalculatorOptions* options) {
*options->mutable_segmenter_options() = segmenter_option.segmenter_options();
@ -181,7 +183,7 @@ absl::StatusOr<const Tensor*> GetOutputTensor(
// input_stream: "IMAGE:image"
// output_stream: "SEGMENTATION:segmented_masks"
// options {
// [mediapipe.tasks.vision.image_segmenter.proto.ImageSegmenterOptions.ext]
// [mediapipe.tasks.vision.image_segmenter.proto.ImageSegmenterGraphOptions.ext]
// {
// base_options {
// model_asset {
@ -200,12 +202,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
mediapipe::SubgraphContext* sc) override {
ASSIGN_OR_RETURN(const auto* model_resources,
CreateModelResources<ImageSegmenterOptions>(sc));
CreateModelResources<ImageSegmenterGraphOptions>(sc));
Graph graph;
ASSIGN_OR_RETURN(
auto output_streams,
BuildSegmentationTask(
sc->Options<ImageSegmenterOptions>(), *model_resources,
sc->Options<ImageSegmenterGraphOptions>(), *model_resources,
graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
@ -228,13 +230,13 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
// builder::Graph instance. The segmentation pipeline takes images
// (mediapipe::Image) as the input and returns segmented image mask as output.
//
// task_options: the mediapipe tasks ImageSegmenterOptions proto.
// task_options: the mediapipe tasks ImageSegmenterGraphOptions proto.
// model_resources: the ModelSources object initialized from a segmentation
// model file with model metadata.
// image_in: (mediapipe::Image) stream to run segmentation on.
// graph: the mediapipe builder::Graph instance to be updated.
absl::StatusOr<ImageSegmenterOutputs> BuildSegmentationTask(
const ImageSegmenterOptions& task_options,
const ImageSegmenterGraphOptions& task_options,
const core::ModelResources& model_resources, Source<Image> image_in,
Source<NormalizedRect> norm_rect_in, Graph& graph) {
MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));
@ -293,8 +295,10 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
}
};
REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::vision::ImageSegmenterGraph);
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::vision::image_segmenter::ImageSegmenterGraph);
} // namespace image_segmenter
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -33,7 +33,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
@ -42,6 +42,7 @@ limitations under the License.
namespace mediapipe {
namespace tasks {
namespace vision {
namespace image_segmenter {
namespace {
using ::mediapipe::Image;
@ -547,6 +548,7 @@ TEST_F(LiveStreamModeTest, Succeeds) {
// TODO: Add test for hair segmentation model.
} // namespace
} // namespace image_segmenter
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -19,8 +19,8 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
mediapipe_proto_library(
name = "image_segmenter_options_proto",
srcs = ["image_segmenter_options.proto"],
name = "image_segmenter_graph_options_proto",
srcs = ["image_segmenter_graph_options.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",

View File

@ -21,9 +21,12 @@ import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/proto/segmenter_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
message ImageSegmenterOptions {
option java_package = "com.google.mediapipe.tasks.vision.imagesegmenter.proto";
option java_outer_classname = "ImageSegmenterGraphOptionsProto";
message ImageSegmenterGraphOptions {
extend mediapipe.CalculatorOptions {
optional ImageSegmenterOptions ext = 458105758;
optional ImageSegmenterGraphOptions ext = 458105758;
}
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
// model file with metadata, accelerator options, etc.

View File

@ -22,6 +22,7 @@ _CORE_TASKS_JAVA_PROTO_LITE_TARGETS = [
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
"//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_java_proto_lite",
"//mediapipe/tasks/cc/components/proto:segmenter_options_java_proto_lite",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite",
@ -36,6 +37,7 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite",

View File

@ -70,7 +70,7 @@ py_library(
"//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/components/proto:segmenter_options_py_pb2",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_py_pb2",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_py_pb2",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info",

View File

@ -22,7 +22,7 @@ from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet
from mediapipe.tasks.cc.components.proto import segmenter_options_pb2
from mediapipe.tasks.cc.vision.image_segmenter.proto import image_segmenter_options_pb2
from mediapipe.tasks.cc.vision.image_segmenter.proto import image_segmenter_graph_options_pb2
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
@ -31,7 +31,7 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode
_BaseOptions = base_options_module.BaseOptions
_SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions
_ImageSegmenterOptionsProto = image_segmenter_options_pb2.ImageSegmenterOptions
_ImageSegmenterGraphOptionsProto = image_segmenter_graph_options_pb2.ImageSegmenterGraphOptions
_RunningMode = vision_task_running_mode.VisionTaskRunningMode
_TaskInfo = task_info_module.TaskInfo
@ -40,7 +40,7 @@ _SEGMENTATION_TAG = 'GROUPED_SEGMENTATION'
_IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.ImageSegmenterGraph'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph'
_MICRO_SECONDS_PER_MILLISECOND = 1000
@ -81,13 +81,13 @@ class ImageSegmenterOptions:
[List[image_module.Image], image_module.Image, int], None]] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ImageSegmenterOptionsProto:
def to_pb2(self) -> _ImageSegmenterGraphOptionsProto:
"""Generates an ImageSegmenterOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
segmenter_options_proto = _SegmenterOptionsProto(
output_type=self.output_type.value, activation=self.activation.value)
return _ImageSegmenterOptionsProto(
return _ImageSegmenterGraphOptionsProto(
base_options=base_options_proto,
segmenter_options=segmenter_options_proto)