internal update
PiperOrigin-RevId: 533197055
This commit is contained in:
parent
a1755044ea
commit
c248525eeb
|
@ -291,8 +291,11 @@ class TensorsToSegmentationCalculator : public Node {
|
|||
static constexpr Output<Image>::Multiple kConfidenceMaskOut{
|
||||
"CONFIDENCE_MASK"};
|
||||
static constexpr Output<Image>::Optional kCategoryMaskOut{"CATEGORY_MASK"};
|
||||
static constexpr Output<std::vector<float>>::Optional kQualityScoresOut{
|
||||
"QUALITY_SCORES"};
|
||||
MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kOutputSizeIn, kSegmentationOut,
|
||||
kConfidenceMaskOut, kCategoryMaskOut);
|
||||
kConfidenceMaskOut, kCategoryMaskOut,
|
||||
kQualityScoresOut);
|
||||
|
||||
static absl::Status UpdateContract(CalculatorContract* cc);
|
||||
|
||||
|
@ -345,12 +348,33 @@ absl::Status TensorsToSegmentationCalculator::Open(
|
|||
|
||||
absl::Status TensorsToSegmentationCalculator::Process(
|
||||
mediapipe::CalculatorContext* cc) {
|
||||
RET_CHECK_EQ(kTensorsIn(cc).Get().size(), 1)
|
||||
<< "Expect a vector of single Tensor.";
|
||||
const auto& input_tensor = kTensorsIn(cc).Get()[0];
|
||||
const auto& input_tensors = kTensorsIn(cc).Get();
|
||||
if (input_tensors.size() != 1 && input_tensors.size() != 2) {
|
||||
return absl::InvalidArgumentError(
|
||||
"Expect input tensor vector of size 1 or 2.");
|
||||
}
|
||||
const auto& input_tensor = *input_tensors.rbegin();
|
||||
ASSIGN_OR_RETURN(const Shape input_shape,
|
||||
GetImageLikeTensorShape(input_tensor));
|
||||
|
||||
// TODO: should use tensor signature to get the correct output
|
||||
// tensor.
|
||||
if (input_tensors.size() == 2) {
|
||||
const auto& quality_tensor = input_tensors[0];
|
||||
const float* quality_score_buffer =
|
||||
quality_tensor.GetCpuReadView().buffer<float>();
|
||||
const std::vector<float> quality_scores(
|
||||
quality_score_buffer,
|
||||
quality_score_buffer +
|
||||
(quality_tensor.bytes() / quality_tensor.element_size()));
|
||||
kQualityScoresOut(cc).Send(quality_scores);
|
||||
} else {
|
||||
// If the input_tensors don't contain quality scores, send the default
|
||||
// quality scores as 1.
|
||||
const std::vector<float> quality_scores(input_shape.channels, 1.0f);
|
||||
kQualityScoresOut(cc).Send(quality_scores);
|
||||
}
|
||||
|
||||
// Category mask does not require activation function.
|
||||
if (options_.segmenter_options().output_type() ==
|
||||
SegmenterOptions::CONFIDENCE_MASK &&
|
||||
|
|
|
@ -46,6 +46,8 @@ constexpr char kImageOutStreamName[] = "image_out";
|
|||
constexpr char kImageTag[] = "IMAGE";
|
||||
constexpr char kNormRectStreamName[] = "norm_rect_in";
|
||||
constexpr char kNormRectTag[] = "NORM_RECT";
|
||||
constexpr char kQualityScoresStreamName[] = "quality_scores";
|
||||
constexpr char kQualityScoresTag[] = "QUALITY_SCORES";
|
||||
constexpr char kSubgraphTypeName[] =
|
||||
"mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
|
||||
constexpr int kMicroSecondsPerMilliSecond = 1000;
|
||||
|
@ -77,6 +79,8 @@ CalculatorGraphConfig CreateGraphConfig(
|
|||
task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
|
||||
graph.Out(kCategoryMaskTag);
|
||||
}
|
||||
task_subgraph.Out(kQualityScoresTag).SetName(kQualityScoresStreamName) >>
|
||||
graph.Out(kQualityScoresTag);
|
||||
task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
|
||||
graph.Out(kImageTag);
|
||||
if (enable_flow_limiting) {
|
||||
|
@ -172,9 +176,13 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
|
|||
category_mask =
|
||||
status_or_packets.value()[kCategoryMaskStreamName].Get<Image>();
|
||||
}
|
||||
const std::vector<float>& quality_scores =
|
||||
status_or_packets.value()[kQualityScoresStreamName]
|
||||
.Get<std::vector<float>>();
|
||||
Packet image_packet = status_or_packets.value()[kImageOutStreamName];
|
||||
result_callback(
|
||||
{{confidence_masks, category_mask}}, image_packet.Get<Image>(),
|
||||
{{confidence_masks, category_mask, quality_scores}},
|
||||
image_packet.Get<Image>(),
|
||||
image_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond);
|
||||
};
|
||||
}
|
||||
|
@ -227,7 +235,9 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::Segment(
|
|||
if (output_category_mask_) {
|
||||
category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
|
||||
}
|
||||
return {{confidence_masks, category_mask}};
|
||||
const std::vector<float>& quality_scores =
|
||||
output_packets[kQualityScoresStreamName].Get<std::vector<float>>();
|
||||
return {{confidence_masks, category_mask, quality_scores}};
|
||||
}
|
||||
|
||||
absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
|
||||
|
@ -260,7 +270,9 @@ absl::StatusOr<ImageSegmenterResult> ImageSegmenter::SegmentForVideo(
|
|||
if (output_category_mask_) {
|
||||
category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
|
||||
}
|
||||
return {{confidence_masks, category_mask}};
|
||||
const std::vector<float>& quality_scores =
|
||||
output_packets[kQualityScoresStreamName].Get<std::vector<float>>();
|
||||
return {{confidence_masks, category_mask, quality_scores}};
|
||||
}
|
||||
|
||||
absl::Status ImageSegmenter::SegmentAsync(
|
||||
|
|
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
|
@ -81,6 +82,7 @@ constexpr char kImageGpuTag[] = "IMAGE_GPU";
|
|||
constexpr char kNormRectTag[] = "NORM_RECT";
|
||||
constexpr char kTensorsTag[] = "TENSORS";
|
||||
constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
|
||||
constexpr char kQualityScoresTag[] = "QUALITY_SCORES";
|
||||
constexpr char kSegmentationMetadataName[] = "SEGMENTER_METADATA";
|
||||
|
||||
// Struct holding the different output streams produced by the image segmenter
|
||||
|
@ -90,6 +92,7 @@ struct ImageSegmenterOutputs {
|
|||
std::optional<std::vector<Source<Image>>> confidence_masks;
|
||||
std::optional<Source<Image>> category_mask;
|
||||
// The same as the input image, mainly used for live stream mode.
|
||||
std::optional<Source<std::vector<float>>> quality_scores;
|
||||
Source<Image> image;
|
||||
};
|
||||
|
||||
|
@ -191,18 +194,11 @@ absl::Status ConfigureTensorsToSegmentationCalculator(
|
|||
"Segmentation tflite models are assumed to have a single subgraph.",
|
||||
MediaPipeTasksStatus::kInvalidArgumentError);
|
||||
}
|
||||
const auto* primary_subgraph = (*model.subgraphs())[0];
|
||||
if (primary_subgraph->outputs()->size() != 1) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
"Segmentation tflite models are assumed to have a single output.",
|
||||
MediaPipeTasksStatus::kInvalidArgumentError);
|
||||
}
|
||||
|
||||
ASSIGN_OR_RETURN(
|
||||
*options->mutable_label_items(),
|
||||
GetLabelItemsIfAny(*metadata_extractor,
|
||||
*metadata_extractor->GetOutputTensorMetadata()->Get(0),
|
||||
GetLabelItemsIfAny(
|
||||
*metadata_extractor,
|
||||
**metadata_extractor->GetOutputTensorMetadata()->crbegin(),
|
||||
segmenter_option.display_names_locale()));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -213,10 +209,16 @@ absl::StatusOr<const tflite::Tensor*> GetOutputTensor(
|
|||
const tflite::Model& model = *model_resources.GetTfLiteModel();
|
||||
const auto* primary_subgraph = (*model.subgraphs())[0];
|
||||
const auto* output_tensor =
|
||||
(*primary_subgraph->tensors())[(*primary_subgraph->outputs())[0]];
|
||||
(*primary_subgraph->tensors())[*(*primary_subgraph->outputs()).rbegin()];
|
||||
return output_tensor;
|
||||
}
|
||||
|
||||
// Returns the number of output tensors declared by the first subgraph of the
// tflite model held by the given model resources.
// NOTE(review): no check is made here that the model has at least one
// subgraph; presumably callers validate that beforehand -- confirm.
uint32_t GetOutputTensorsSize(const core::ModelResources& model_resources) {
|
||||
const tflite::Model& model = *model_resources.GetTfLiteModel();
|
||||
const auto* primary_subgraph = (*model.subgraphs())[0];
|
||||
return primary_subgraph->outputs()->size();
|
||||
}
|
||||
|
||||
// Get the input tensor from the tflite model of given model resources.
|
||||
absl::StatusOr<const tflite::Tensor*> GetInputTensor(
|
||||
const core::ModelResources& model_resources) {
|
||||
|
@ -433,6 +435,10 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
|
|||
*output_streams.category_mask >> graph[Output<Image>(kCategoryMaskTag)];
|
||||
}
|
||||
}
|
||||
if (output_streams.quality_scores) {
|
||||
*output_streams.quality_scores >>
|
||||
graph[Output<std::vector<float>>::Optional(kQualityScoresTag)];
|
||||
}
|
||||
output_streams.image >> graph[Output<Image>(kImageTag)];
|
||||
return graph.GetConfig();
|
||||
}
|
||||
|
@ -530,9 +536,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
|
|||
tensor_to_images[Output<Image>::Multiple(kSegmentationTag)][i]));
|
||||
}
|
||||
}
|
||||
auto quality_scores =
|
||||
tensor_to_images[Output<std::vector<float>>(kQualityScoresTag)];
|
||||
return ImageSegmenterOutputs{/*segmented_masks=*/segmented_masks,
|
||||
/*confidence_masks=*/std::nullopt,
|
||||
/*category_mask=*/std::nullopt,
|
||||
/*quality_scores=*/quality_scores,
|
||||
/*image=*/image_and_tensors.image};
|
||||
} else {
|
||||
std::optional<std::vector<Source<Image>>> confidence_masks;
|
||||
|
@ -552,9 +561,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
|
|||
if (output_category_mask_) {
|
||||
category_mask = tensor_to_images[Output<Image>(kCategoryMaskTag)];
|
||||
}
|
||||
auto quality_scores =
|
||||
tensor_to_images[Output<std::vector<float>>(kQualityScoresTag)];
|
||||
return ImageSegmenterOutputs{/*segmented_masks=*/std::nullopt,
|
||||
/*confidence_masks=*/confidence_masks,
|
||||
/*category_mask=*/category_mask,
|
||||
/*quality_scores=*/quality_scores,
|
||||
/*image=*/image_and_tensors.image};
|
||||
}
|
||||
}
|
||||
|
|
|
@ -33,6 +33,10 @@ struct ImageSegmenterResult {
|
|||
// A category mask of uint8 image in GRAY8 format where each pixel represents
|
||||
// the class which the pixel in the original image was predicted to belong to.
|
||||
std::optional<Image> category_mask;
|
||||
// The quality scores of the result masks, in the range of [0, 1]. Default to
|
||||
// `1` if the model doesn't output quality scores. Each element corresponds to
|
||||
// the score of the category in the model outputs.
|
||||
std::vector<float> quality_scores;
|
||||
};
|
||||
|
||||
} // namespace image_segmenter
|
||||
|
|
|
@ -51,12 +51,14 @@ constexpr char kImageInStreamName[] = "image_in";
|
|||
constexpr char kImageOutStreamName[] = "image_out";
|
||||
constexpr char kRoiStreamName[] = "roi_in";
|
||||
constexpr char kNormRectStreamName[] = "norm_rect_in";
|
||||
constexpr char kQualityScoresStreamName[] = "quality_scores";
|
||||
|
||||
constexpr absl::string_view kConfidenceMasksTag{"CONFIDENCE_MASKS"};
|
||||
constexpr absl::string_view kCategoryMaskTag{"CATEGORY_MASK"};
|
||||
constexpr absl::string_view kImageTag{"IMAGE"};
|
||||
constexpr absl::string_view kRoiTag{"ROI"};
|
||||
constexpr absl::string_view kNormRectTag{"NORM_RECT"};
|
||||
constexpr absl::string_view kQualityScoresTag{"QUALITY_SCORES"};
|
||||
|
||||
constexpr absl::string_view kSubgraphTypeName{
|
||||
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"};
|
||||
|
@ -91,6 +93,8 @@ CalculatorGraphConfig CreateGraphConfig(
|
|||
task_subgraph.Out(kCategoryMaskTag).SetName(kCategoryMaskStreamName) >>
|
||||
graph.Out(kCategoryMaskTag);
|
||||
}
|
||||
task_subgraph.Out(kQualityScoresTag).SetName(kQualityScoresStreamName) >>
|
||||
graph.Out(kQualityScoresTag);
|
||||
task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
|
||||
graph.Out(kImageTag);
|
||||
graph.In(kImageTag) >> task_subgraph.In(kImageTag);
|
||||
|
@ -201,7 +205,9 @@ absl::StatusOr<ImageSegmenterResult> InteractiveSegmenter::Segment(
|
|||
if (output_category_mask_) {
|
||||
category_mask = output_packets[kCategoryMaskStreamName].Get<Image>();
|
||||
}
|
||||
return {{confidence_masks, category_mask}};
|
||||
const std::vector<float>& quality_scores =
|
||||
output_packets[kQualityScoresStreamName].Get<std::vector<float>>();
|
||||
return {{confidence_masks, category_mask, quality_scores}};
|
||||
}
|
||||
|
||||
} // namespace interactive_segmenter
|
||||
|
|
|
@ -58,6 +58,7 @@ constexpr absl::string_view kAlphaTag{"ALPHA"};
|
|||
constexpr absl::string_view kAlphaGpuTag{"ALPHA_GPU"};
|
||||
constexpr absl::string_view kNormRectTag{"NORM_RECT"};
|
||||
constexpr absl::string_view kRoiTag{"ROI"};
|
||||
constexpr absl::string_view kQualityScoresTag{"QUALITY_SCORES"};
|
||||
|
||||
// Updates the graph to return `roi` stream which has same dimension as
|
||||
// `image`, and rendered with `roi`. If `use_gpu` is true, returned `Source` is
|
||||
|
@ -200,6 +201,8 @@ class InteractiveSegmenterGraph : public core::ModelTaskGraph {
|
|||
graph[Output<Image>(kCategoryMaskTag)];
|
||||
}
|
||||
}
|
||||
image_segmenter.Out(kQualityScoresTag) >>
|
||||
graph[Output<std::vector<float>>::Optional(kQualityScoresTag)];
|
||||
image_segmenter.Out(kImageTag) >> graph[Output<Image>(kImageTag)];
|
||||
|
||||
return graph.GetConfig();
|
||||
|
|
|
@ -115,6 +115,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
segmenterOptions.outputCategoryMask()
|
||||
? getStreamIndex.apply(outputStreams, "CATEGORY_MASK:category_mask")
|
||||
: -1;
|
||||
final int qualityScoresOutStreamIndex =
|
||||
getStreamIndex.apply(outputStreams, "QUALITY_SCORES:quality_scores");
|
||||
final int imageOutStreamIndex = getStreamIndex.apply(outputStreams, "IMAGE:image_out");
|
||||
|
||||
// TODO: Consolidate OutputHandler and TaskRunner.
|
||||
|
@ -128,6 +130,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
return ImageSegmenterResult.create(
|
||||
Optional.empty(),
|
||||
Optional.empty(),
|
||||
new ArrayList<>(),
|
||||
packets.get(imageOutStreamIndex).getTimestamp());
|
||||
}
|
||||
boolean copyImage = !segmenterOptions.resultListener().isPresent();
|
||||
|
@ -182,9 +185,16 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_ALPHA);
|
||||
categoryMask = Optional.of(builder.build());
|
||||
}
|
||||
float[] qualityScores =
|
||||
PacketGetter.getFloat32Vector(packets.get(qualityScoresOutStreamIndex));
|
||||
List<Float> qualityScoresList = new ArrayList<>(qualityScores.length);
|
||||
for (float score : qualityScores) {
|
||||
qualityScoresList.add(score);
|
||||
}
|
||||
return ImageSegmenterResult.create(
|
||||
confidenceMasks,
|
||||
categoryMask,
|
||||
qualityScoresList,
|
||||
BaseVisionTaskApi.generateResultTimestampMs(
|
||||
segmenterOptions.runningMode(), packets.get(imageOutStreamIndex)));
|
||||
}
|
||||
|
@ -592,8 +602,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
public abstract Builder setOutputCategoryMask(boolean value);
|
||||
|
||||
/**
|
||||
* Sets an optional {@link ResultListener} to receive the segmentation results when the graph
|
||||
* pipeline is done processing an image.
|
||||
* Sets an optional {@link ResultListener} to receive the segmentation results when the
|
||||
* graph pipeline is done processing an image.
|
||||
*/
|
||||
public abstract Builder setResultListener(
|
||||
ResultListener<ImageSegmenterResult, MPImage> value);
|
||||
|
|
|
@ -34,19 +34,30 @@ public abstract class ImageSegmenterResult implements TaskResult {
|
|||
* @param categoryMask an {@link Optional} MPImage in IMAGE_FORMAT_ALPHA format representing a
|
||||
* category mask, where each pixel represents the class which the pixel in the original image
|
||||
* was predicted to belong to.
|
||||
* @param qualityScores The quality scores of the result masks, in the range of [0, 1]. Default to
|
||||
* `1` if the model doesn't output quality scores. Each element corresponds to the score of
|
||||
* the category in the model outputs.
|
||||
* @param timestampMs a timestamp for this result.
|
||||
*/
|
||||
// TODO: consolidate output formats across platforms.
|
||||
public static ImageSegmenterResult create(
|
||||
Optional<List<MPImage>> confidenceMasks, Optional<MPImage> categoryMask, long timestampMs) {
|
||||
Optional<List<MPImage>> confidenceMasks,
|
||||
Optional<MPImage> categoryMask,
|
||||
List<Float> qualityScores,
|
||||
long timestampMs) {
|
||||
return new AutoValue_ImageSegmenterResult(
|
||||
confidenceMasks.map(Collections::unmodifiableList), categoryMask, timestampMs);
|
||||
confidenceMasks.map(Collections::unmodifiableList),
|
||||
categoryMask,
|
||||
Collections.unmodifiableList(qualityScores),
|
||||
timestampMs);
|
||||
}
|
||||
|
||||
public abstract Optional<List<MPImage>> confidenceMasks();
|
||||
|
||||
public abstract Optional<MPImage> categoryMask();
|
||||
|
||||
public abstract List<Float> qualityScores();
|
||||
|
||||
@Override
|
||||
public abstract long timestampMs();
|
||||
}
|
||||
|
|
|
@ -127,6 +127,10 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
|||
outputStreams.add("CATEGORY_MASK:category_mask");
|
||||
}
|
||||
final int categoryMaskOutStreamIndex = outputStreams.size() - 1;
|
||||
|
||||
outputStreams.add("QUALITY_SCORES:quality_scores");
|
||||
final int qualityScoresOutStreamIndex = outputStreams.size() - 1;
|
||||
|
||||
outputStreams.add("IMAGE:image_out");
|
||||
// TODO: add test for stream indices.
|
||||
final int imageOutStreamIndex = outputStreams.size() - 1;
|
||||
|
@ -142,6 +146,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
|||
return ImageSegmenterResult.create(
|
||||
Optional.empty(),
|
||||
Optional.empty(),
|
||||
new ArrayList<>(),
|
||||
packets.get(imageOutStreamIndex).getTimestamp());
|
||||
}
|
||||
// If resultListener is not provided, the resulted MPImage is deep copied from
|
||||
|
@ -199,9 +204,17 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
|||
categoryMask = Optional.of(builder.build());
|
||||
}
|
||||
|
||||
float[] qualityScores =
|
||||
PacketGetter.getFloat32Vector(packets.get(qualityScoresOutStreamIndex));
|
||||
List<Float> qualityScoresList = new ArrayList<>(qualityScores.length);
|
||||
for (float score : qualityScores) {
|
||||
qualityScoresList.add(score);
|
||||
}
|
||||
|
||||
return ImageSegmenterResult.create(
|
||||
confidenceMasks,
|
||||
categoryMask,
|
||||
qualityScoresList,
|
||||
BaseVisionTaskApi.generateResultTimestampMs(
|
||||
RunningMode.IMAGE, packets.get(imageOutStreamIndex)));
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user