internal update

PiperOrigin-RevId: 533197055

parent a1755044ea
commit c248525eeb
@@ -291,8 +291,11 @@ class TensorsToSegmentationCalculator : public Node {
   static constexpr Output<Image>::Multiple kConfidenceMaskOut{
       "CONFIDENCE_MASK"};
   static constexpr Output<Image>::Optional kCategoryMaskOut{"CATEGORY_MASK"};
+  static constexpr Output<std::vector<float>>::Optional kQualityScoresOut{
+      "QUALITY_SCORES"};
   MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kOutputSizeIn, kSegmentationOut,
-                          kConfidenceMaskOut, kCategoryMaskOut);
+                          kConfidenceMaskOut, kCategoryMaskOut,
+                          kQualityScoresOut);
 
   static absl::Status UpdateContract(CalculatorContract* cc);
 
@@ -345,12 +348,33 @@ absl::Status TensorsToSegmentationCalculator::Open(
 
 absl::Status TensorsToSegmentationCalculator::Process(
     mediapipe::CalculatorContext* cc) {
-  RET_CHECK_EQ(kTensorsIn(cc).Get().size(), 1)
-      << "Expect a vector of single Tensor.";
-  const auto& input_tensor = kTensorsIn(cc).Get()[0];
+  const auto& input_tensors = kTensorsIn(cc).Get();
+  if (input_tensors.size() != 1 && input_tensors.size() != 2) {
+    return absl::InvalidArgumentError(
+        "Expect input tensor vector of size 1 or 2.");
+  }
+  const auto& input_tensor = *input_tensors.rbegin();
   ASSIGN_OR_RETURN(const Shape input_shape,
                    GetImageLikeTensorShape(input_tensor));
+
+  // TODO: should use tensor signature to get the correct output
+  // tensor.
+  if (input_tensors.size() == 2) {
+    const auto& quality_tensor = input_tensors[0];
+    const float* quality_score_buffer =
+        quality_tensor.GetCpuReadView().buffer<float>();
+    const std::vector<float> quality_scores(
+        quality_score_buffer,
+        quality_score_buffer +
+            (quality_tensor.bytes() / quality_tensor.element_size()));
+    kQualityScoresOut(cc).Send(quality_scores);
+  } else {
+    // If the input_tensors don't contain quality scores, send the default
+    // quality scores as 1.
+    const std::vector<float> quality_scores(input_shape.channels, 1.0f);
+    kQualityScoresOut(cc).Send(quality_scores);
+  }
 
   // Category mask does not require activation function.
   if (options_.segmenter_options().output_type() ==
           SegmenterOptions::CONFIDENCE_MASK &&
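Note on the Process() change above: when the model emits two tensors, the first is treated as the per-category quality scores and its float buffer is copied into a vector; with a single tensor, every category defaults to a score of 1.0. A minimal standalone sketch of that fallback logic (plain C++, not the MediaPipe Tensor API; `num_categories` is a hypothetical stand-in for `input_shape.channels`):

    #include <cstddef>
    #include <vector>

    // Copy scores out of an optional raw float buffer, or default every
    // category's score to 1.0f when no quality tensor is present.
    std::vector<float> BuildQualityScores(const float* buffer, std::size_t count,
                                          int num_categories) {
      if (buffer != nullptr) {
        return std::vector<float>(buffer, buffer + count);
      }
      return std::vector<float>(num_categories, 1.0f);
    }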
@@ -46,6 +46,8 @@ constexpr char kImageOutStreamName[] = "image_out";
 constexpr char kImageTag[] = "IMAGE";
 constexpr char kNormRectStreamName[] = "norm_rect_in";
 constexpr char kNormRectTag[] = "NORM_RECT";
+constexpr char kQualityScoresStreamName[] = "quality_scores";
+constexpr char kQualityScoresTag[] = "QUALITY_SCORES";
 constexpr char kSubgraphTypeName[] =
     "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
 constexpr int kMicroSecondsPerMilliSecond = 1000;
@@ -77,6 +79,8 @@ CalculatorGraphConfig CreateGraphConfig(
     task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
         graph.Out(kCategoryMaskTag);
   }
+  task_subgraph.Out(kQualityScoresTag).SetName(kQualityScoresStreamName) >>
+      graph.Out(kQualityScoresTag);
   task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
       graph.Out(kImageTag);
   if (enable_flow_limiting) {
@@ -172,9 +176,13 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
           category_mask =
               status_or_packets.value()[kCategoryMaskStreamName].Get<Image>();
         }
+        const std::vector<float>& quality_scores =
+            status_or_packets.value()[kQualityScoresStreamName]
+                .Get<std::vector<float>>();
         Packet image_packet = status_or_packets.value()[kImageOutStreamName];
         result_callback(
-            {{confidence_masks, category_mask}}, image_packet.Get<Image>(),
+            {{confidence_masks, category_mask, quality_scores}},
+            image_packet.Get<Image>(),
             image_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond);
       };
 }
@@ -227,7 +235,9 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::Segment(
   if (output_category_mask_) {
     category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
   }
-  return {{confidence_masks, category_mask}};
+  const std::vector<float>& quality_scores =
+      output_packets[kQualityScoresStreamName].Get<std::vector<float>>();
+  return {{confidence_masks, category_mask, quality_scores}};
 }
 
 absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
@@ -260,7 +270,9 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
   if (output_category_mask_) {
     category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
   }
-  return {{confidence_masks, category_mask}};
+  const std::vector<float>& quality_scores =
+      output_packets[kQualityScoresStreamName].Get<std::vector<float>>();
+  return {{confidence_masks, category_mask, quality_scores}};
 }
 
 absl::Status ImageSegmenter::SegmentAsync(
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include <cstdint>
 #include <memory>
 #include <optional>
 #include <type_traits>
@@ -81,6 +82,7 @@ constexpr char kImageGpuTag[] = "IMAGE_GPU";
 constexpr char kNormRectTag[] = "NORM_RECT";
 constexpr char kTensorsTag[] = "TENSORS";
 constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
+constexpr char kQualityScoresTag[] = "QUALITY_SCORES";
 constexpr char kSegmentationMetadataName[] = "SEGMENTER_METADATA";
 
 // Struct holding the different output streams produced by the image segmenter
@@ -90,6 +92,7 @@ struct ImageSegmenterOutputs {
   std::optional<std::vector<Source<Image>>> confidence_masks;
   std::optional<Source<Image>> category_mask;
   // The same as the input image, mainly used for live stream mode.
+  std::optional<Source<std::vector<float>>> quality_scores;
   Source<Image> image;
 };
 
@@ -191,18 +194,11 @@ absl::Status ConfigureTensorsToSegmentationCalculator(
         "Segmentation tflite models are assumed to have a single subgraph.",
         MediaPipeTasksStatus::kInvalidArgumentError);
   }
-  const auto* primary_subgraph = (*model.subgraphs())[0];
-  if (primary_subgraph->outputs()->size() != 1) {
-    return CreateStatusWithPayload(
-        absl::StatusCode::kInvalidArgument,
-        "Segmentation tflite models are assumed to have a single output.",
-        MediaPipeTasksStatus::kInvalidArgumentError);
-  }
-
   ASSIGN_OR_RETURN(
       *options->mutable_label_items(),
-      GetLabelItemsIfAny(*metadata_extractor,
-                         *metadata_extractor->GetOutputTensorMetadata()->Get(0),
+      GetLabelItemsIfAny(
+          *metadata_extractor,
+          **metadata_extractor->GetOutputTensorMetadata()->crbegin(),
          segmenter_option.display_names_locale()));
   return absl::OkStatus();
 }
@@ -213,10 +209,16 @@ absl::StatusOr<const tflite::Tensor*> GetOutputTensor(
   const tflite::Model& model = *model_resources.GetTfLiteModel();
   const auto* primary_subgraph = (*model.subgraphs())[0];
   const auto* output_tensor =
-      (*primary_subgraph->tensors())[(*primary_subgraph->outputs())[0]];
+      (*primary_subgraph->tensors())[*(*primary_subgraph->outputs()).rbegin()];
   return output_tensor;
 }
 
+uint32_t GetOutputTensorsSize(const core::ModelResources& model_resources) {
+  const tflite::Model& model = *model_resources.GetTfLiteModel();
+  const auto* primary_subgraph = (*model.subgraphs())[0];
+  return primary_subgraph->outputs()->size();
+}
+
 // Get the input tensor from the tflite model of given model resources.
 absl::StatusOr<const tflite::Tensor*> GetInputTensor(
     const core::ModelResources& model_resources) {
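Note on GetOutputTensor() and the new GetOutputTensorsSize() above: with an optional quality head the model can now have one or two outputs, and the segmentation tensor is taken as the last output rather than output 0 (hence the switch to `rbegin()`, and to `crbegin()` for the output tensor metadata earlier). A tiny sketch of that convention (plain C++ over a hypothetical index vector, not the TFLite flatbuffer API):

    #include <vector>

    // With outputs ordered {quality_scores, segmentation} or just
    // {segmentation}, the segmentation tensor index is always the last one.
    // Assumes the model has at least one output.
    int SegmentationOutputIndex(const std::vector<int>& output_indices) {
      return output_indices.back();  // equivalent to *output_indices.rbegin()
    }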
@@ -433,6 +435,10 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
         *output_streams.category_mask >> graph[Output<Image>(kCategoryMaskTag)];
       }
     }
+    if (output_streams.quality_scores) {
+      *output_streams.quality_scores >>
+          graph[Output<std::vector<float>>::Optional(kQualityScoresTag)];
+    }
     output_streams.image >> graph[Output<Image>(kImageTag)];
     return graph.GetConfig();
   }
@@ -530,9 +536,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
             tensor_to_images[Output<Image>::Multiple(kSegmentationTag)][i]));
       }
     }
+    auto quality_scores =
+        tensor_to_images[Output<std::vector<float>>(kQualityScoresTag)];
     return ImageSegmenterOutputs{/*segmented_masks=*/segmented_masks,
                                  /*confidence_masks=*/std::nullopt,
                                  /*category_mask=*/std::nullopt,
+                                 /*quality_scores=*/quality_scores,
                                  /*image=*/image_and_tensors.image};
   } else {
     std::optional<std::vector<Source<Image>>> confidence_masks;
@@ -552,9 +561,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
     if (output_category_mask_) {
       category_mask = tensor_to_images[Output<Image>(kCategoryMaskTag)];
     }
+    auto quality_scores =
+        tensor_to_images[Output<std::vector<float>>(kQualityScoresTag)];
     return ImageSegmenterOutputs{/*segmented_masks=*/std::nullopt,
                                  /*confidence_masks=*/confidence_masks,
                                  /*category_mask=*/category_mask,
+                                 /*quality_scores=*/quality_scores,
                                  /*image=*/image_and_tensors.image};
   }
 }
@@ -33,6 +33,10 @@ struct ImageSegmenterResult {
   // A category mask of uint8 image in GRAY8 format where each pixel represents
   // the class which the pixel in the original image was predicted to belong to.
   std::optional<Image> category_mask;
+  // The quality scores of the result masks, in the range of [0, 1]. Default to
+  // `1` if the model doesn't output quality scores. Each element corresponds to
+  // the score of the category in the model outputs.
+  std::vector<float> quality_scores;
 };
 
 }  // namespace image_segmenter
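With the `quality_scores` field added to ImageSegmenterResult above, a caller can pair each per-category mask with its score, since element i corresponds to category i of the model outputs and defaults to 1.0 when the model has no quality head. A hedged usage sketch in plain C++ (the `SimpleResult` struct below is a stand-in for the real result type, not the MediaPipe class):

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // Stand-in for the result fields relevant here.
    struct SimpleResult {
      std::vector<int> category_ids;      // one entry per output category
      std::vector<float> quality_scores;  // one score per category, default 1.0f
    };

    // Keep only the categories whose mask quality clears `min_quality`.
    std::vector<int> KeepHighQualityCategories(const SimpleResult& result,
                                               float min_quality) {
      std::vector<int> kept;
      const std::size_t n =
          std::min(result.category_ids.size(), result.quality_scores.size());
      for (std::size_t i = 0; i < n; ++i) {
        if (result.quality_scores[i] >= min_quality) {
          kept.push_back(result.category_ids[i]);
        }
      }
      return kept;
    }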
@@ -51,12 +51,14 @@ constexpr char kImageInStreamName[] = "image_in";
 constexpr char kImageOutStreamName[] = "image_out";
 constexpr char kRoiStreamName[] = "roi_in";
 constexpr char kNormRectStreamName[] = "norm_rect_in";
+constexpr char kQualityScoresStreamName[] = "quality_scores";
 
 constexpr absl::string_view kConfidenceMasksTag{"CONFIDENCE_MASKS"};
 constexpr absl::string_view kCategoryMaskTag{"CATEGORY_MASK"};
 constexpr absl::string_view kImageTag{"IMAGE"};
 constexpr absl::string_view kRoiTag{"ROI"};
 constexpr absl::string_view kNormRectTag{"NORM_RECT"};
+constexpr absl::string_view kQualityScoresTag{"QUALITY_SCORES"};
 
 constexpr absl::string_view kSubgraphTypeName{
     "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"};
@@ -91,6 +93,8 @@ CalculatorGraphConfig CreateGraphConfig(
     task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
         graph.Out(kCategoryMaskTag);
   }
+  task_subgraph.Out(kQualityScoresTag).SetName(kQualityScoresStreamName) >>
+      graph.Out(kQualityScoresTag);
   task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
       graph.Out(kImageTag);
   graph.In(kImageTag) >> task_subgraph.In(kImageTag);
@@ -201,7 +205,9 @@ absl::StatusOr<ImageSegmenterResult> InteractiveSegmenter::Segment(
   if (output_category_mask_) {
     category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
   }
-  return {{confidence_masks, category_mask}};
+  const std::vector<float>& quality_scores =
+      output_packets[kQualityScoresStreamName].Get<std::vector<float>>();
+  return {{confidence_masks, category_mask, quality_scores}};
 }
 
 }  // namespace interactive_segmenter
@@ -58,6 +58,7 @@ constexpr absl::string_view kAlphaTag{"ALPHA"};
 constexpr absl::string_view kAlphaGpuTag{"ALPHA_GPU"};
 constexpr absl::string_view kNormRectTag{"NORM_RECT"};
 constexpr absl::string_view kRoiTag{"ROI"};
+constexpr absl::string_view kQualityScoresTag{"QUALITY_SCORES"};
 
 // Updates the graph to return `roi` stream which has same dimension as
 // `image`, and rendered with `roi`. If `use_gpu` is true, returned `Source` is
@@ -200,6 +201,8 @@ class InteractiveSegmenterGraph : public core::ModelTaskGraph {
             graph[Output<Image>(kCategoryMaskTag)];
       }
     }
+    image_segmenter.Out(kQualityScoresTag) >>
+        graph[Output<std::vector<float>>::Optional(kQualityScoresTag)];
     image_segmenter.Out(kImageTag) >> graph[Output<Image>(kImageTag)];
 
     return graph.GetConfig();
@@ -115,6 +115,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
         segmenterOptions.outputCategoryMask()
             ? getStreamIndex.apply(outputStreams, "CATEGORY_MASK:category_mask")
             : -1;
+    final int qualityScoresOutStreamIndex =
+        getStreamIndex.apply(outputStreams, "QUALITY_SCORES:quality_scores");
     final int imageOutStreamIndex = getStreamIndex.apply(outputStreams, "IMAGE:image_out");
 
     // TODO: Consolidate OutputHandler and TaskRunner.
@@ -128,6 +130,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
             return ImageSegmenterResult.create(
                 Optional.empty(),
                 Optional.empty(),
+                new ArrayList<>(),
                 packets.get(imageOutStreamIndex).getTimestamp());
           }
           boolean copyImage = !segmenterOptions.resultListener().isPresent();
@@ -182,9 +185,16 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
                 new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_ALPHA);
             categoryMask = Optional.of(builder.build());
           }
+          float[] qualityScores =
+              PacketGetter.getFloat32Vector(packets.get(qualityScoresOutStreamIndex));
+          List<Float> qualityScoresList = new ArrayList<>(qualityScores.length);
+          for (float score : qualityScores) {
+            qualityScoresList.add(score);
+          }
           return ImageSegmenterResult.create(
               confidenceMasks,
               categoryMask,
+              qualityScoresList,
               BaseVisionTaskApi.generateResultTimestampMs(
                   segmenterOptions.runningMode(), packets.get(imageOutStreamIndex)));
         }
@@ -592,8 +602,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
     public abstract Builder setOutputCategoryMask(boolean value);
 
     /**
-     * Sets an optional {@link ResultListener} to receive the segmentation results when the graph
-     * pipeline is done processing an image.
+     * /** Sets an optional {@link ResultListener} to receive the segmentation results when the
+     * graph pipeline is done processing an image.
      */
     public abstract Builder setResultListener(
         ResultListener<ImageSegmenterResult, MPImage> value);
@@ -34,19 +34,30 @@ public abstract class ImageSegmenterResult implements TaskResult {
    * @param categoryMask an {@link Optional} MPImage in IMAGE_FORMAT_ALPHA format representing a
    *     category mask, where each pixel represents the class which the pixel in the original image
    *     was predicted to belong to.
+   * @param qualityScores The quality scores of the result masks, in the range of [0, 1]. Default to
+   *     `1` if the model doesn't output quality scores. Each element corresponds to the score of
+   *     the category in the model outputs.
    * @param timestampMs a timestamp for this result.
    */
   // TODO: consolidate output formats across platforms.
   public static ImageSegmenterResult create(
-      Optional<List<MPImage>> confidenceMasks, Optional<MPImage> categoryMask, long timestampMs) {
+      Optional<List<MPImage>> confidenceMasks,
+      Optional<MPImage> categoryMask,
+      List<Float> qualityScores,
+      long timestampMs) {
     return new AutoValue_ImageSegmenterResult(
-        confidenceMasks.map(Collections::unmodifiableList), categoryMask, timestampMs);
+        confidenceMasks.map(Collections::unmodifiableList),
+        categoryMask,
+        Collections.unmodifiableList(qualityScores),
+        timestampMs);
   }
 
   public abstract Optional<List<MPImage>> confidenceMasks();
 
   public abstract Optional<MPImage> categoryMask();
 
+  public abstract List<Float> qualityScores();
+
   @Override
   public abstract long timestampMs();
 }
@@ -127,6 +127,10 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
       outputStreams.add("CATEGORY_MASK:category_mask");
     }
     final int categoryMaskOutStreamIndex = outputStreams.size() - 1;
+
+    outputStreams.add("QUALITY_SCORES:quality_scores");
+    final int qualityScoresOutStreamIndex = outputStreams.size() - 1;
+
     outputStreams.add("IMAGE:image_out");
     // TODO: add test for stream indices.
     final int imageOutStreamIndex = outputStreams.size() - 1;
@@ -142,6 +146,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
             return ImageSegmenterResult.create(
                 Optional.empty(),
                 Optional.empty(),
+                new ArrayList<>(),
                 packets.get(imageOutStreamIndex).getTimestamp());
           }
           // If resultListener is not provided, the resulted MPImage is deep copied from
@@ -199,9 +204,17 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
             categoryMask = Optional.of(builder.build());
           }
 
+          float[] qualityScores =
+              PacketGetter.getFloat32Vector(packets.get(qualityScoresOutStreamIndex));
+          List<Float> qualityScoresList = new ArrayList<>(qualityScores.length);
+          for (float score : qualityScores) {
+            qualityScoresList.add(score);
+          }
+
           return ImageSegmenterResult.create(
               confidenceMasks,
               categoryMask,
+              qualityScoresList,
               BaseVisionTaskApi.generateResultTimestampMs(
                   RunningMode.IMAGE, packets.get(imageOutStreamIndex)));
         }