internal update

PiperOrigin-RevId: 533197055
This commit is contained in:
MediaPipe Team 2023-05-18 11:37:13 -07:00 committed by Copybara-Service
parent a1755044ea
commit c248525eeb
9 changed files with 119 additions and 24 deletions

View File

@ -291,8 +291,11 @@ class TensorsToSegmentationCalculator : public Node {
static constexpr Output<Image>::Multiple kConfidenceMaskOut{
"CONFIDENCE_MASK"};
static constexpr Output<Image>::Optional kCategoryMaskOut{"CATEGORY_MASK"};
static constexpr Output<std::vector<float>>::Optional kQualityScoresOut{
"QUALITY_SCORES"};
MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kOutputSizeIn, kSegmentationOut,
kConfidenceMaskOut, kCategoryMaskOut);
kConfidenceMaskOut, kCategoryMaskOut,
kQualityScoresOut);
static absl::Status UpdateContract(CalculatorContract* cc);
@ -345,12 +348,33 @@ absl::Status TensorsToSegmentationCalculator::Open(
absl::Status TensorsToSegmentationCalculator::Process(
mediapipe::CalculatorContext* cc) {
RET_CHECK_EQ(kTensorsIn(cc).Get().size(), 1)
<< "Expect a vector of single Tensor.";
const auto& input_tensor = kTensorsIn(cc).Get()[0];
const auto& input_tensors = kTensorsIn(cc).Get();
if (input_tensors.size() != 1 && input_tensors.size() != 2) {
return absl::InvalidArgumentError(
"Expect input tensor vector of size 1 or 2.");
}
const auto& input_tensor = *input_tensors.rbegin();
ASSIGN_OR_RETURN(const Shape input_shape,
GetImageLikeTensorShape(input_tensor));
// TODO: should use tensor signature to get the correct output
// tensor.
if (input_tensors.size() == 2) {
const auto& quality_tensor = input_tensors[0];
const float* quality_score_buffer =
quality_tensor.GetCpuReadView().buffer<float>();
const std::vector<float> quality_scores(
quality_score_buffer,
quality_score_buffer +
(quality_tensor.bytes() / quality_tensor.element_size()));
kQualityScoresOut(cc).Send(quality_scores);
} else {
// If the input_tensors don't contain quality scores, send the default
// quality scores as 1.
const std::vector<float> quality_scores(input_shape.channels, 1.0f);
kQualityScoresOut(cc).Send(quality_scores);
}
// Category mask does not require activation function.
if (options_.segmenter_options().output_type() ==
SegmenterOptions::CONFIDENCE_MASK &&

View File

@ -46,6 +46,8 @@ constexpr char kImageOutStreamName[] = "image_out";
constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectStreamName[] = "norm_rect_in";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kQualityScoresStreamName[] = "quality_scores";
constexpr char kQualityScoresTag[] = "QUALITY_SCORES";
constexpr char kSubgraphTypeName[] =
"mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
constexpr int kMicroSecondsPerMilliSecond = 1000;
@ -77,6 +79,8 @@ CalculatorGraphConfig CreateGraphConfig(
task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
graph.Out(kCategoryMaskTag);
}
task_subgraph.Out(kQualityScoresTag).SetName(kQualityScoresStreamName) >>
graph.Out(kQualityScoresTag);
task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
graph.Out(kImageTag);
if (enable_flow_limiting) {
@ -172,9 +176,13 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
category_mask =
status_or_packets.value()[kCategoryMaskStreamName].Get<Image>();
}
const std::vector<float>& quality_scores =
status_or_packets.value()[kQualityScoresStreamName]
.Get<std::vector<float>>();
Packet image_packet = status_or_packets.value()[kImageOutStreamName];
result_callback(
{{confidence_masks, category_mask}}, image_packet.Get<Image>(),
{{confidence_masks, category_mask, quality_scores}},
image_packet.Get<Image>(),
image_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond);
};
}
@ -227,7 +235,9 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::Segment(
if (output_category_mask_) {
category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
}
return {{confidence_masks, category_mask}};
const std::vector<float>& quality_scores =
output_packets[kQualityScoresStreamName].Get<std::vector<float>>();
return {{confidence_masks, category_mask, quality_scores}};
}
absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
@ -260,7 +270,9 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
if (output_category_mask_) {
category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
}
return {{confidence_masks, category_mask}};
const std::vector<float>& quality_scores =
output_packets[kQualityScoresStreamName].Get<std::vector<float>>();
return {{confidence_masks, category_mask, quality_scores}};
}
absl::Status ImageSegmenter::SegmentAsync(

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <memory>
#include <optional>
#include <type_traits>
@ -81,6 +82,7 @@ constexpr char kImageGpuTag[] = "IMAGE_GPU";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
constexpr char kQualityScoresTag[] = "QUALITY_SCORES";
constexpr char kSegmentationMetadataName[] = "SEGMENTER_METADATA";
// Struct holding the different output streams produced by the image segmenter
@ -90,6 +92,7 @@ struct ImageSegmenterOutputs {
std::optional<std::vector<Source<Image>>> confidence_masks;
std::optional<Source<Image>> category_mask;
// The same as the input image, mainly used for live stream mode.
std::optional<Source<std::vector<float>>> quality_scores;
Source<Image> image;
};
@ -191,19 +194,12 @@ absl::Status ConfigureTensorsToSegmentationCalculator(
"Segmentation tflite models are assumed to have a single subgraph.",
MediaPipeTasksStatus::kInvalidArgumentError);
}
const auto* primary_subgraph = (*model.subgraphs())[0];
if (primary_subgraph->outputs()->size() != 1) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"Segmentation tflite models are assumed to have a single output.",
MediaPipeTasksStatus::kInvalidArgumentError);
}
ASSIGN_OR_RETURN(
*options->mutable_label_items(),
GetLabelItemsIfAny(*metadata_extractor,
*metadata_extractor->GetOutputTensorMetadata()->Get(0),
segmenter_option.display_names_locale()));
GetLabelItemsIfAny(
*metadata_extractor,
**metadata_extractor->GetOutputTensorMetadata()->crbegin(),
segmenter_option.display_names_locale()));
return absl::OkStatus();
}
@ -213,10 +209,16 @@ absl::StatusOr<const tflite::Tensor*> GetOutputTensor(
const tflite::Model& model = *model_resources.GetTfLiteModel();
const auto* primary_subgraph = (*model.subgraphs())[0];
const auto* output_tensor =
(*primary_subgraph->tensors())[(*primary_subgraph->outputs())[0]];
(*primary_subgraph->tensors())[*(*primary_subgraph->outputs()).rbegin()];
return output_tensor;
}
uint32_t GetOutputTensorsSize(const core::ModelResources& model_resources) {
const tflite::Model& model = *model_resources.GetTfLiteModel();
const auto* primary_subgraph = (*model.subgraphs())[0];
return primary_subgraph->outputs()->size();
}
// Get the input tensor from the tflite model of given model resources.
absl::StatusOr<const tflite::Tensor*> GetInputTensor(
const core::ModelResources& model_resources) {
@ -433,6 +435,10 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
*output_streams.category_mask >> graph[Output<Image>(kCategoryMaskTag)];
}
}
if (output_streams.quality_scores) {
*output_streams.quality_scores >>
graph[Output<std::vector<float>>::Optional(kQualityScoresTag)];
}
output_streams.image >> graph[Output<Image>(kImageTag)];
return graph.GetConfig();
}
@ -530,9 +536,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
tensor_to_images[Output<Image>::Multiple(kSegmentationTag)][i]));
}
}
auto quality_scores =
tensor_to_images[Output<std::vector<float>>(kQualityScoresTag)];
return ImageSegmenterOutputs{/*segmented_masks=*/segmented_masks,
/*confidence_masks=*/std::nullopt,
/*category_mask=*/std::nullopt,
/*quality_scores=*/quality_scores,
/*image=*/image_and_tensors.image};
} else {
std::optional<std::vector<Source<Image>>> confidence_masks;
@ -552,9 +561,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
if (output_category_mask_) {
category_mask = tensor_to_images[Output<Image>(kCategoryMaskTag)];
}
auto quality_scores =
tensor_to_images[Output<std::vector<float>>(kQualityScoresTag)];
return ImageSegmenterOutputs{/*segmented_masks=*/std::nullopt,
/*confidence_masks=*/confidence_masks,
/*category_mask=*/category_mask,
/*quality_scores=*/quality_scores,
/*image=*/image_and_tensors.image};
}
}

View File

@ -33,6 +33,10 @@ struct ImageSegmenterResult {
// A category mask of uint8 image in GRAY8 format where each pixel represents
// the class which the pixel in the original image was predicted to belong to.
std::optional<Image> category_mask;
// The quality scores of the result masks, in the range of [0, 1]. Default to
// `1` if the model doesn't output quality scores. Each element corresponds to
// the score of the category in the model outputs.
std::vector<float> quality_scores;
};
} // namespace image_segmenter

View File

@ -51,12 +51,14 @@ constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out";
constexpr char kRoiStreamName[] = "roi_in";
constexpr char kNormRectStreamName[] = "norm_rect_in";
constexpr char kQualityScoresStreamName[] = "quality_scores";
constexpr absl::string_view kConfidenceMasksTag{"CONFIDENCE_MASKS"};
constexpr absl::string_view kCategoryMaskTag{"CATEGORY_MASK"};
constexpr absl::string_view kImageTag{"IMAGE"};
constexpr absl::string_view kRoiTag{"ROI"};
constexpr absl::string_view kNormRectTag{"NORM_RECT"};
constexpr absl::string_view kQualityScoresTag{"QUALITY_SCORES"};
constexpr absl::string_view kSubgraphTypeName{
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"};
@ -91,6 +93,8 @@ CalculatorGraphConfig CreateGraphConfig(
task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
graph.Out(kCategoryMaskTag);
}
task_subgraph.Out(kQualityScoresTag).SetName(kQualityScoresStreamName) >>
graph.Out(kQualityScoresTag);
task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
graph.Out(kImageTag);
graph.In(kImageTag) >> task_subgraph.In(kImageTag);
@ -201,7 +205,9 @@ absl::StatusOr<ImageSegmenterResult> InteractiveSegmenter::Segment(
if (output_category_mask_) {
category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
}
return {{confidence_masks, category_mask}};
const std::vector<float>& quality_scores =
output_packets[kQualityScoresStreamName].Get<std::vector<float>>();
return {{confidence_masks, category_mask, quality_scores}};
}
} // namespace interactive_segmenter

View File

@ -58,6 +58,7 @@ constexpr absl::string_view kAlphaTag{"ALPHA"};
constexpr absl::string_view kAlphaGpuTag{"ALPHA_GPU"};
constexpr absl::string_view kNormRectTag{"NORM_RECT"};
constexpr absl::string_view kRoiTag{"ROI"};
constexpr absl::string_view kQualityScoresTag{"QUALITY_SCORES"};
// Updates the graph to return `roi` stream which has same dimension as
// `image`, and rendered with `roi`. If `use_gpu` is true, returned `Source` is
@ -200,6 +201,8 @@ class InteractiveSegmenterGraph : public core::ModelTaskGraph {
graph[Output<Image>(kCategoryMaskTag)];
}
}
image_segmenter.Out(kQualityScoresTag) >>
graph[Output<std::vector<float>>::Optional(kQualityScoresTag)];
image_segmenter.Out(kImageTag) >> graph[Output<Image>(kImageTag)];
return graph.GetConfig();

View File

@ -115,6 +115,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
segmenterOptions.outputCategoryMask()
? getStreamIndex.apply(outputStreams, "CATEGORY_MASK:category_mask")
: -1;
final int qualityScoresOutStreamIndex =
getStreamIndex.apply(outputStreams, "QUALITY_SCORES:quality_scores");
final int imageOutStreamIndex = getStreamIndex.apply(outputStreams, "IMAGE:image_out");
// TODO: Consolidate OutputHandler and TaskRunner.
@ -128,6 +130,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
return ImageSegmenterResult.create(
Optional.empty(),
Optional.empty(),
new ArrayList<>(),
packets.get(imageOutStreamIndex).getTimestamp());
}
boolean copyImage = !segmenterOptions.resultListener().isPresent();
@ -182,9 +185,16 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_ALPHA);
categoryMask = Optional.of(builder.build());
}
float[] qualityScores =
PacketGetter.getFloat32Vector(packets.get(qualityScoresOutStreamIndex));
List<Float> qualityScoresList = new ArrayList<>(qualityScores.length);
for (float score : qualityScores) {
qualityScoresList.add(score);
}
return ImageSegmenterResult.create(
confidenceMasks,
categoryMask,
qualityScoresList,
BaseVisionTaskApi.generateResultTimestampMs(
segmenterOptions.runningMode(), packets.get(imageOutStreamIndex)));
}
@ -592,8 +602,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
public abstract Builder setOutputCategoryMask(boolean value);
/**
* Sets an optional {@link ResultListener} to receive the segmentation results when the graph
* pipeline is done processing an image.
* /** Sets an optional {@link ResultListener} to receive the segmentation results when the
* graph pipeline is done processing an image.
*/
public abstract Builder setResultListener(
ResultListener<ImageSegmenterResult, MPImage> value);

View File

@ -34,19 +34,30 @@ public abstract class ImageSegmenterResult implements TaskResult {
* @param categoryMask an {@link Optional} MPImage in IMAGE_FORMAT_ALPHA format representing a
* category mask, where each pixel represents the class which the pixel in the original image
* was predicted to belong to.
* @param qualityScores The quality scores of the result masks, in the range of [0, 1]. Default to
* `1` if the model doesn't output quality scores. Each element corresponds to the score of
* the category in the model outputs.
* @param timestampMs a timestamp for this result.
*/
// TODO: consolidate output formats across platforms.
public static ImageSegmenterResult create(
Optional<List<MPImage>> confidenceMasks, Optional<MPImage> categoryMask, long timestampMs) {
Optional<List<MPImage>> confidenceMasks,
Optional<MPImage> categoryMask,
List<Float> qualityScores,
long timestampMs) {
return new AutoValue_ImageSegmenterResult(
confidenceMasks.map(Collections::unmodifiableList), categoryMask, timestampMs);
confidenceMasks.map(Collections::unmodifiableList),
categoryMask,
Collections.unmodifiableList(qualityScores),
timestampMs);
}
public abstract Optional<List<MPImage>> confidenceMasks();
public abstract Optional<MPImage> categoryMask();
public abstract List<Float> qualityScores();
@Override
public abstract long timestampMs();
}

View File

@ -127,6 +127,10 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
outputStreams.add("CATEGORY_MASK:category_mask");
}
final int categoryMaskOutStreamIndex = outputStreams.size() - 1;
outputStreams.add("QUALITY_SCORES:quality_scores");
final int qualityScoresOutStreamIndex = outputStreams.size() - 1;
outputStreams.add("IMAGE:image_out");
// TODO: add test for stream indices.
final int imageOutStreamIndex = outputStreams.size() - 1;
@ -142,6 +146,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
return ImageSegmenterResult.create(
Optional.empty(),
Optional.empty(),
new ArrayList<>(),
packets.get(imageOutStreamIndex).getTimestamp());
}
// If resultListener is not provided, the resulted MPImage is deep copied from
@ -199,9 +204,17 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
categoryMask = Optional.of(builder.build());
}
float[] qualityScores =
PacketGetter.getFloat32Vector(packets.get(qualityScoresOutStreamIndex));
List<Float> qualityScoresList = new ArrayList<>(qualityScores.length);
for (float score : qualityScores) {
qualityScoresList.add(score);
}
return ImageSegmenterResult.create(
confidenceMasks,
categoryMask,
qualityScoresList,
BaseVisionTaskApi.generateResultTimestampMs(
RunningMode.IMAGE, packets.get(imageOutStreamIndex)));
}