Merge branch 'master' into ios-code-review-fixes

This commit is contained in:
Prianka Liz Kariat 2023-05-24 19:57:46 +05:30
commit e2e90dcac6
27 changed files with 155 additions and 75 deletions

View File

@ -68,11 +68,11 @@ StatusBuilder&& StatusBuilder::SetNoLogging() && {
return std::move(SetNoLogging()); return std::move(SetNoLogging());
} }
StatusBuilder::operator Status() const& { StatusBuilder::operator absl::Status() const& {
return StatusBuilder(*this).JoinMessageToStatus(); return StatusBuilder(*this).JoinMessageToStatus();
} }
StatusBuilder::operator Status() && { return JoinMessageToStatus(); } StatusBuilder::operator absl::Status() && { return JoinMessageToStatus(); }
absl::Status StatusBuilder::JoinMessageToStatus() { absl::Status StatusBuilder::JoinMessageToStatus() {
if (!impl_) { if (!impl_) {

View File

@ -83,8 +83,8 @@ class ABSL_MUST_USE_RESULT StatusBuilder {
return std::move(*this << msg); return std::move(*this << msg);
} }
operator Status() const&; operator absl::Status() const&;
operator Status() &&; operator absl::Status() &&;
absl::Status JoinMessageToStatus(); absl::Status JoinMessageToStatus();

View File

@ -403,11 +403,11 @@ std::ostream &operator<<(std::ostream &os,
lhs op## = rhs; \ lhs op## = rhs; \
return lhs; \ return lhs; \
} }
STRONG_INT_VS_STRONG_INT_BINARY_OP(+); STRONG_INT_VS_STRONG_INT_BINARY_OP(+)
STRONG_INT_VS_STRONG_INT_BINARY_OP(-); STRONG_INT_VS_STRONG_INT_BINARY_OP(-)
STRONG_INT_VS_STRONG_INT_BINARY_OP(&); STRONG_INT_VS_STRONG_INT_BINARY_OP(&)
STRONG_INT_VS_STRONG_INT_BINARY_OP(|); STRONG_INT_VS_STRONG_INT_BINARY_OP(|)
STRONG_INT_VS_STRONG_INT_BINARY_OP(^); STRONG_INT_VS_STRONG_INT_BINARY_OP(^)
#undef STRONG_INT_VS_STRONG_INT_BINARY_OP #undef STRONG_INT_VS_STRONG_INT_BINARY_OP
// Define operators that take one StrongInt and one native integer argument. // Define operators that take one StrongInt and one native integer argument.
@ -431,12 +431,12 @@ STRONG_INT_VS_STRONG_INT_BINARY_OP(^);
rhs op## = lhs; \ rhs op## = lhs; \
return rhs; \ return rhs; \
} }
STRONG_INT_VS_NUMERIC_BINARY_OP(*); STRONG_INT_VS_NUMERIC_BINARY_OP(*)
NUMERIC_VS_STRONG_INT_BINARY_OP(*); NUMERIC_VS_STRONG_INT_BINARY_OP(*)
STRONG_INT_VS_NUMERIC_BINARY_OP(/); STRONG_INT_VS_NUMERIC_BINARY_OP(/)
STRONG_INT_VS_NUMERIC_BINARY_OP(%); STRONG_INT_VS_NUMERIC_BINARY_OP(%)
STRONG_INT_VS_NUMERIC_BINARY_OP(<<); // NOLINT(whitespace/operators) STRONG_INT_VS_NUMERIC_BINARY_OP(<<) // NOLINT(whitespace/operators)
STRONG_INT_VS_NUMERIC_BINARY_OP(>>); // NOLINT(whitespace/operators) STRONG_INT_VS_NUMERIC_BINARY_OP(>>) // NOLINT(whitespace/operators)
#undef STRONG_INT_VS_NUMERIC_BINARY_OP #undef STRONG_INT_VS_NUMERIC_BINARY_OP
#undef NUMERIC_VS_STRONG_INT_BINARY_OP #undef NUMERIC_VS_STRONG_INT_BINARY_OP
@ -447,12 +447,12 @@ STRONG_INT_VS_NUMERIC_BINARY_OP(>>); // NOLINT(whitespace/operators)
StrongInt<TagType, ValueType, ValidatorType> rhs) { \ StrongInt<TagType, ValueType, ValidatorType> rhs) { \
return lhs.value() op rhs.value(); \ return lhs.value() op rhs.value(); \
} }
STRONG_INT_COMPARISON_OP(==); // NOLINT(whitespace/operators) STRONG_INT_COMPARISON_OP(==) // NOLINT(whitespace/operators)
STRONG_INT_COMPARISON_OP(!=); // NOLINT(whitespace/operators) STRONG_INT_COMPARISON_OP(!=) // NOLINT(whitespace/operators)
STRONG_INT_COMPARISON_OP(<); // NOLINT(whitespace/operators) STRONG_INT_COMPARISON_OP(<) // NOLINT(whitespace/operators)
STRONG_INT_COMPARISON_OP(<=); // NOLINT(whitespace/operators) STRONG_INT_COMPARISON_OP(<=) // NOLINT(whitespace/operators)
STRONG_INT_COMPARISON_OP(>); // NOLINT(whitespace/operators) STRONG_INT_COMPARISON_OP(>) // NOLINT(whitespace/operators)
STRONG_INT_COMPARISON_OP(>=); // NOLINT(whitespace/operators) STRONG_INT_COMPARISON_OP(>=) // NOLINT(whitespace/operators)
#undef STRONG_INT_COMPARISON_OP #undef STRONG_INT_COMPARISON_OP
} // namespace intops } // namespace intops

View File

@ -57,7 +57,7 @@ namespace mediapipe {
// have underflow/overflow etc. This type is used internally by Timestamp // have underflow/overflow etc. This type is used internally by Timestamp
// and TimestampDiff. // and TimestampDiff.
MEDIAPIPE_DEFINE_SAFE_INT_TYPE(TimestampBaseType, int64, MEDIAPIPE_DEFINE_SAFE_INT_TYPE(TimestampBaseType, int64,
mediapipe::intops::LogFatalOnError); mediapipe::intops::LogFatalOnError)
class TimestampDiff; class TimestampDiff;

View File

@ -272,17 +272,20 @@ DEFINE_MEDIAPIPE_TYPE_MAP(PacketTypeStringToMediaPipeTypeData, std::string);
#define MEDIAPIPE_REGISTER_TYPE(type, type_name, serialize_fn, deserialize_fn) \ #define MEDIAPIPE_REGISTER_TYPE(type, type_name, serialize_fn, deserialize_fn) \
SET_MEDIAPIPE_TYPE_MAP_VALUE( \ SET_MEDIAPIPE_TYPE_MAP_VALUE( \
mediapipe::PacketTypeIdToMediaPipeTypeData, \ mediapipe::PacketTypeIdToMediaPipeTypeData, \
mediapipe::tool::GetTypeHash< \ mediapipe::TypeId::Of< \
mediapipe::type_map_internal::ReflectType<void(type*)>::Type>(), \ mediapipe::type_map_internal::ReflectType<void(type*)>::Type>() \
.hash_code(), \
(mediapipe::MediaPipeTypeData{ \ (mediapipe::MediaPipeTypeData{ \
mediapipe::tool::GetTypeHash< \ mediapipe::TypeId::Of< \
mediapipe::type_map_internal::ReflectType<void(type*)>::Type>(), \ mediapipe::type_map_internal::ReflectType<void(type*)>::Type>() \
.hash_code(), \
type_name, serialize_fn, deserialize_fn})); \ type_name, serialize_fn, deserialize_fn})); \
SET_MEDIAPIPE_TYPE_MAP_VALUE( \ SET_MEDIAPIPE_TYPE_MAP_VALUE( \
mediapipe::PacketTypeStringToMediaPipeTypeData, type_name, \ mediapipe::PacketTypeStringToMediaPipeTypeData, type_name, \
(mediapipe::MediaPipeTypeData{ \ (mediapipe::MediaPipeTypeData{ \
mediapipe::tool::GetTypeHash< \ mediapipe::TypeId::Of< \
mediapipe::type_map_internal::ReflectType<void(type*)>::Type>(), \ mediapipe::type_map_internal::ReflectType<void(type*)>::Type>() \
.hash_code(), \
type_name, serialize_fn, deserialize_fn})); type_name, serialize_fn, deserialize_fn}));
// End define MEDIAPIPE_REGISTER_TYPE. // End define MEDIAPIPE_REGISTER_TYPE.

View File

@ -187,7 +187,7 @@ class PerceptualLoss(tf.keras.Model, metaclass=abc.ABCMeta):
"""Instantiates perceptual loss. """Instantiates perceptual loss.
Args: Args:
feature_weight: The weight coeffcients of multiple model extracted feature_weight: The weight coefficients of multiple model extracted
features used for calculating the perceptual loss. features used for calculating the perceptual loss.
loss_weight: The weight coefficients between `style_loss` and loss_weight: The weight coefficients between `style_loss` and
`content_loss`. `content_loss`.

View File

@ -105,7 +105,7 @@ class FaceStylizer(object):
self._train_model(train_data=train_data, preprocessor=self._preprocessor) self._train_model(train_data=train_data, preprocessor=self._preprocessor)
def _create_model(self): def _create_model(self):
"""Creates the componenets of face stylizer.""" """Creates the components of face stylizer."""
self._encoder = model_util.load_keras_model( self._encoder = model_util.load_keras_model(
constants.FACE_STYLIZER_ENCODER_MODEL_FILES.get_path() constants.FACE_STYLIZER_ENCODER_MODEL_FILES.get_path()
) )
@ -138,7 +138,7 @@ class FaceStylizer(object):
""" """
train_dataset = train_data.gen_tf_dataset(preprocess=preprocessor) train_dataset = train_data.gen_tf_dataset(preprocess=preprocessor)
# TODO: Support processing mulitple input style images. The # TODO: Support processing multiple input style images. The
# input style images are expected to have similar style. # input style images are expected to have similar style.
# style_sample represents a tuple of (style_image, style_label). # style_sample represents a tuple of (style_image, style_label).
style_sample = next(iter(train_dataset)) style_sample = next(iter(train_dataset))

View File

@ -33,7 +33,7 @@ struct ImageSegmenterResult {
// A category mask of uint8 image in GRAY8 format where each pixel represents // A category mask of uint8 image in GRAY8 format where each pixel represents
// the class which the pixel in the original image was predicted to belong to. // the class which the pixel in the original image was predicted to belong to.
std::optional<Image> category_mask; std::optional<Image> category_mask;
// The quality scores of the result masks, in the range of [0, 1]. Default to // The quality scores of the result masks, in the range of [0, 1]. Defaults to
// `1` if the model doesn't output quality scores. Each element corresponds to // `1` if the model doesn't output quality scores. Each element corresponds to
// the score of the category in the model outputs. // the score of the category in the model outputs.
std::vector<float> quality_scores; std::vector<float> quality_scores;

View File

@ -29,7 +29,7 @@ static const float kSimilarityDiffTolerance = 1e-4;
XCTAssertNotNil(error); \ XCTAssertNotNil(error); \
XCTAssertEqualObjects(error.domain, expectedError.domain); \ XCTAssertEqualObjects(error.domain, expectedError.domain); \
XCTAssertEqual(error.code, expectedError.code); \ XCTAssertEqual(error.code, expectedError.code); \
XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription) \ XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
#define AssertTextEmbedderResultHasOneEmbedding(textEmbedderResult) \ #define AssertTextEmbedderResultHasOneEmbedding(textEmbedderResult) \
XCTAssertNotNil(textEmbedderResult); \ XCTAssertNotNil(textEmbedderResult); \

View File

@ -34,7 +34,7 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
XCTAssertNotNil(error); \ XCTAssertNotNil(error); \
XCTAssertEqualObjects(error.domain, expectedError.domain); \ XCTAssertEqualObjects(error.domain, expectedError.domain); \
XCTAssertEqual(error.code, expectedError.code); \ XCTAssertEqual(error.code, expectedError.code); \
XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription) \ XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
#define AssertEqualCategoryArrays(categories, expectedCategories) \ #define AssertEqualCategoryArrays(categories, expectedCategories) \
XCTAssertEqual(categories.count, expectedCategories.count); \ XCTAssertEqual(categories.count, expectedCategories.count); \
@ -668,10 +668,10 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
// Because of flow limiting, we cannot ensure that the callback will be // Because of flow limiting, we cannot ensure that the callback will be
// invoked `iterationCount` times. // invoked `iterationCount` times.
// A normal expectation will fail if expectation.fullfill() is not called // A normal expectation will fail if expectation.fulfill() is not called
// `expectation.expectedFulfillmentCount` times. // `expectation.expectedFulfillmentCount` times.
// If `expectation.isInverted = true`, the test will only succeed if // If `expectation.isInverted = true`, the test will only succeed if
// expectation is not fullfilled for the specified `expectedFulfillmentCount`. // expectation is not fulfilled for the specified `expectedFulfillmentCount`.
// Since in our case we cannot predict how many times the expectation is // Since in our case we cannot predict how many times the expectation is
// supposed to be fulfilled setting, // supposed to be fulfilled setting,
// `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and // `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and

View File

@ -673,15 +673,15 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
// Because of flow limiting, we cannot ensure that the callback will be // Because of flow limiting, we cannot ensure that the callback will be
// invoked `iterationCount` times. // invoked `iterationCount` times.
// A normal expectation will fail if expectation.fullfill() is not called // A normal expectation will fail if expectation.fulfill() is not called
// `expectation.expectedFulfillmentCount` times. // `expectation.expectedFulfillmentCount` times.
// If `expectation.isInverted = true`, the test will only succeed if // If `expectation.isInverted = true`, the test will only succeed if
// expectation is not fullfilled for the specified `expectedFulfillmentCount`. // expectation is not fulfilled for the specified `expectedFulfillmentCount`.
// Since in our case we cannot predict how many times the expectation is // Since in our case we cannot predict how many times the expectation is
// supposed to be fullfilled setting, // supposed to be fulfilled setting,
// `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and // `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
// `expectation.isInverted = true` ensures that test succeeds if // `expectation.isInverted = true` ensures that test succeeds if
// expectation is fullfilled <= `iterationCount` times. // expectation is fulfilled <= `iterationCount` times.
XCTestExpectation *expectation = [[XCTestExpectation alloc] XCTestExpectation *expectation = [[XCTestExpectation alloc]
initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"]; initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"];
expectation.expectedFulfillmentCount = iterationCount + 1; expectation.expectedFulfillmentCount = iterationCount + 1;

View File

@ -17,7 +17,7 @@
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
static const int kMicroSecondsPerMilliSecond = 1000; static const int kMicrosecondsPerMillisecond = 1000;
namespace { namespace {
using ClassificationResultProto = using ClassificationResultProto =
@ -29,9 +29,9 @@ using ::mediapipe::Packet;
+ (nullable MPPImageClassifierResult *)imageClassifierResultWithClassificationsPacket: + (nullable MPPImageClassifierResult *)imageClassifierResultWithClassificationsPacket:
(const Packet &)packet { (const Packet &)packet {
// Even if packet does not validate a the expected type, you can safely access the timestamp. // Even if packet does not validate as the expected type, you can safely access the timestamp.
NSInteger timestampInMilliSeconds = NSInteger timestampInMilliSeconds =
(NSInteger)(packet.Timestamp().Value() / kMicroSecondsPerMilliSecond); (NSInteger)(packet.Timestamp().Value() / kMicrosecondsPerMillisecond);
if (!packet.ValidateAsType<ClassificationResultProto>().ok()) { if (!packet.ValidateAsType<ClassificationResultProto>().ok()) {
// MPPClassificationResult's timestamp is populated from timestamp `ClassificationResultProto`'s // MPPClassificationResult's timestamp is populated from timestamp `ClassificationResultProto`'s
@ -48,7 +48,7 @@ using ::mediapipe::Packet;
return [[MPPImageClassifierResult alloc] return [[MPPImageClassifierResult alloc]
initWithClassificationResult:classificationResult initWithClassificationResult:classificationResult
timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() / timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond)]; kMicrosecondsPerMillisecond)];
} }
@end @end

View File

@ -166,7 +166,7 @@ public class BaseVisionTaskApi implements AutoCloseable {
// For 90° and 270° rotations, we need to swap width and height. // For 90° and 270° rotations, we need to swap width and height.
// This is due to the internal behavior of ImageToTensorCalculator, which: // This is due to the internal behavior of ImageToTensorCalculator, which:
// - first denormalizes the provided rect by multiplying the rect width or // - first denormalizes the provided rect by multiplying the rect width or
// height by the image width or height, repectively. // height by the image width or height, respectively.
// - then rotates this by denormalized rect by the provided rotation, and // - then rotates this by denormalized rect by the provided rotation, and
// uses this for cropping, // uses this for cropping,
// - then finally rotates this back. // - then finally rotates this back.

View File

@ -34,8 +34,8 @@ public abstract class ImageSegmenterResult implements TaskResult {
* @param categoryMask an {@link Optional} MPImage in IMAGE_FORMAT_ALPHA format representing a * @param categoryMask an {@link Optional} MPImage in IMAGE_FORMAT_ALPHA format representing a
* category mask, where each pixel represents the class which the pixel in the original image * category mask, where each pixel represents the class which the pixel in the original image
* was predicted to belong to. * was predicted to belong to.
* @param qualityScores The quality scores of the result masks, in the range of [0, 1]. Default to * @param qualityScores The quality scores of the result masks, in the range of [0, 1]. Defaults
* `1` if the model doesn't output quality scores. Each element corresponds to the score of * to `1` if the model doesn't output quality scores. Each element corresponds to the score of
* the category in the model outputs. * the category in the model outputs.
* @param timestampMs a timestamp for this result. * @param timestampMs a timestamp for this result.
*/ */

View File

@ -27,13 +27,14 @@ export declare interface Detection {
boundingBox?: BoundingBox; boundingBox?: BoundingBox;
/** /**
* Optional list of keypoints associated with the detection. Keypoints * List of keypoints associated with the detection. Keypoints represent
* represent interesting points related to the detection. For example, the * interesting points related to the detection. For example, the keypoints
* keypoints represent the eye, ear and mouth from face detection model. Or * represent the eye, ear and mouth from face detection model. Or in the
* in the template matching detection, e.g. KNIFT, they can represent the * template matching detection, e.g. KNIFT, they can represent the feature
* feature points for template matching. * points for template matching. Contains an empty list if no keypoints are
* detected.
*/ */
keypoints?: NormalizedKeypoint[]; keypoints: NormalizedKeypoint[];
} }
/** Detection results of a model. */ /** Detection results of a model. */

View File

@ -85,7 +85,8 @@ describe('convertFromDetectionProto()', () => {
categoryName: '', categoryName: '',
displayName: '', displayName: '',
}], }],
boundingBox: {originX: 0, originY: 0, width: 0, height: 0} boundingBox: {originX: 0, originY: 0, width: 0, height: 0},
keypoints: []
}); });
}); });
}); });

View File

@ -26,7 +26,7 @@ export function convertFromDetectionProto(source: DetectionProto): Detection {
const labels = source.getLabelList(); const labels = source.getLabelList();
const displayNames = source.getDisplayNameList(); const displayNames = source.getDisplayNameList();
const detection: Detection = {categories: []}; const detection: Detection = {categories: [], keypoints: []};
for (let i = 0; i < scores.length; i++) { for (let i = 0; i < scores.length; i++) {
detection.categories.push({ detection.categories.push({
score: scores[i], score: scores[i],
@ -47,7 +47,6 @@ export function convertFromDetectionProto(source: DetectionProto): Detection {
} }
if (source.getLocationData()?.getRelativeKeypointsList().length) { if (source.getLocationData()?.getRelativeKeypointsList().length) {
detection.keypoints = [];
for (const keypoint of for (const keypoint of
source.getLocationData()!.getRelativeKeypointsList()) { source.getLocationData()!.getRelativeKeypointsList()) {
detection.keypoints.push({ detection.keypoints.push({

View File

@ -191,7 +191,8 @@ describe('FaceDetector', () => {
categoryName: '', categoryName: '',
displayName: '', displayName: '',
}], }],
boundingBox: {originX: 0, originY: 0, width: 0, height: 0} boundingBox: {originX: 0, originY: 0, width: 0, height: 0},
keypoints: []
}); });
}); });
}); });

View File

@ -171,7 +171,7 @@ export class FaceStylizer extends VisionTaskRunner {
/** /**
* Performs face stylization on the provided single image and returns the * Performs face stylization on the provided single image and returns the
* result. This method creates a copy of the resulting image and should not be * result. This method creates a copy of the resulting image and should not be
* used in high-throughput applictions. Only use this method when the * used in high-throughput applications. Only use this method when the
* FaceStylizer is created with the image running mode. * FaceStylizer is created with the image running mode.
* *
* @param image An image to process. * @param image An image to process.
@ -182,7 +182,7 @@ export class FaceStylizer extends VisionTaskRunner {
/** /**
* Performs face stylization on the provided single image and returns the * Performs face stylization on the provided single image and returns the
* result. This method creates a copy of the resulting image and should not be * result. This method creates a copy of the resulting image and should not be
* used in high-throughput applictions. Only use this method when the * used in high-throughput applications. Only use this method when the
* FaceStylizer is created with the image running mode. * FaceStylizer is created with the image running mode.
* *
* The 'imageProcessingOptions' parameter can be used to specify one or all * The 'imageProcessingOptions' parameter can be used to specify one or all
@ -275,7 +275,7 @@ export class FaceStylizer extends VisionTaskRunner {
/** /**
* Performs face stylization on the provided video frame. This method creates * Performs face stylization on the provided video frame. This method creates
* a copy of the resulting image and should not be used in high-throughput * a copy of the resulting image and should not be used in high-throughput
* applictions. Only use this method when the FaceStylizer is created with the * applications. Only use this method when the FaceStylizer is created with the
* video running mode. * video running mode.
* *
* The input frame can be of any size. It's required to provide the video * The input frame can be of any size. It's required to provide the video

View File

@ -39,6 +39,7 @@ const IMAGE_STREAM = 'image_in';
const NORM_RECT_STREAM = 'norm_rect'; const NORM_RECT_STREAM = 'norm_rect';
const CONFIDENCE_MASKS_STREAM = 'confidence_masks'; const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
const CATEGORY_MASK_STREAM = 'category_mask'; const CATEGORY_MASK_STREAM = 'category_mask';
const QUALITY_SCORES_STREAM = 'quality_scores';
const IMAGE_SEGMENTER_GRAPH = const IMAGE_SEGMENTER_GRAPH =
'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph'; 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME = const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
@ -61,6 +62,7 @@ export type ImageSegmenterCallback = (result: ImageSegmenterResult) => void;
export class ImageSegmenter extends VisionTaskRunner { export class ImageSegmenter extends VisionTaskRunner {
private categoryMask?: MPMask; private categoryMask?: MPMask;
private confidenceMasks?: MPMask[]; private confidenceMasks?: MPMask[];
private qualityScores?: number[];
private labels: string[] = []; private labels: string[] = [];
private userCallback?: ImageSegmenterCallback; private userCallback?: ImageSegmenterCallback;
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK; private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
@ -229,7 +231,7 @@ export class ImageSegmenter extends VisionTaskRunner {
/** /**
* Performs image segmentation on the provided single image and returns the * Performs image segmentation on the provided single image and returns the
* segmentation result. This method creates a copy of the resulting masks and * segmentation result. This method creates a copy of the resulting masks and
* should not be used in high-throughput applictions. Only use this method * should not be used in high-throughput applications. Only use this method
* when the ImageSegmenter is created with running mode `image`. * when the ImageSegmenter is created with running mode `image`.
* *
* @param image An image to process. * @param image An image to process.
@ -240,7 +242,7 @@ export class ImageSegmenter extends VisionTaskRunner {
/** /**
* Performs image segmentation on the provided single image and returns the * Performs image segmentation on the provided single image and returns the
* segmentation result. This method creates a copy of the resulting masks and * segmentation result. This method creates a copy of the resulting masks and
* should not be used in high-throughput applictions. Only use this method when * should not be used in high-throughput applications. Only use this method when
* the ImageSegmenter is created with running mode `image`. * the ImageSegmenter is created with running mode `image`.
* *
* @param image An image to process. * @param image An image to process.
@ -318,7 +320,7 @@ export class ImageSegmenter extends VisionTaskRunner {
/** /**
* Performs image segmentation on the provided video frame and returns the * Performs image segmentation on the provided video frame and returns the
* segmentation result. This method creates a copy of the resulting masks and * segmentation result. This method creates a copy of the resulting masks and
* should not be used in high-throughput applictions. Only use this method when * should not be used in high-throughput applications. Only use this method when
* the ImageSegmenter is created with running mode `video`. * the ImageSegmenter is created with running mode `video`.
* *
* @param videoFrame A video frame to process. * @param videoFrame A video frame to process.
@ -367,12 +369,13 @@ export class ImageSegmenter extends VisionTaskRunner {
private reset(): void { private reset(): void {
this.categoryMask = undefined; this.categoryMask = undefined;
this.confidenceMasks = undefined; this.confidenceMasks = undefined;
this.qualityScores = undefined;
} }
private processResults(): ImageSegmenterResult|void { private processResults(): ImageSegmenterResult|void {
try { try {
const result = const result = new ImageSegmenterResult(
new ImageSegmenterResult(this.confidenceMasks, this.categoryMask); this.confidenceMasks, this.categoryMask, this.qualityScores);
if (this.userCallback) { if (this.userCallback) {
this.userCallback(result); this.userCallback(result);
} else { } else {
@ -442,6 +445,20 @@ export class ImageSegmenter extends VisionTaskRunner {
}); });
} }
graphConfig.addOutputStream(QUALITY_SCORES_STREAM);
segmenterNode.addOutputStream('QUALITY_SCORES:' + QUALITY_SCORES_STREAM);
this.graphRunner.attachFloatVectorListener(
QUALITY_SCORES_STREAM, (scores, timestamp) => {
this.qualityScores = scores;
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachEmptyPacketListener(
QUALITY_SCORES_STREAM, timestamp => {
this.categoryMask = undefined;
this.setLatestOutputTimestamp(timestamp);
});
const binaryGraph = graphConfig.serializeBinary(); const binaryGraph = graphConfig.serializeBinary();
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
} }

View File

@ -30,7 +30,13 @@ export class ImageSegmenterResult {
* `WebGLTexture`-backed `MPImage` where each pixel represents the class * `WebGLTexture`-backed `MPImage` where each pixel represents the class
* which the pixel in the original image was predicted to belong to. * which the pixel in the original image was predicted to belong to.
*/ */
readonly categoryMask?: MPMask) {} readonly categoryMask?: MPMask,
/**
* The quality scores of the result masks, in the range of [0, 1].
* Defaults to `1` if the model doesn't output quality scores. Each
* element corresponds to the score of the category in the model outputs.
*/
readonly qualityScores?: number[]) {}
/** Frees the resources held by the category and confidence masks. */ /** Frees the resources held by the category and confidence masks. */
close(): void { close(): void {

View File

@ -35,6 +35,8 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
((images: WasmImage, timestamp: number) => void)|undefined; ((images: WasmImage, timestamp: number) => void)|undefined;
confidenceMasksListener: confidenceMasksListener:
((images: WasmImage[], timestamp: number) => void)|undefined; ((images: WasmImage[], timestamp: number) => void)|undefined;
qualityScoresListener:
((data: number[], timestamp: number) => void)|undefined;
constructor() { constructor() {
super(createSpyWasmModule(), /* glCanvas= */ null); super(createSpyWasmModule(), /* glCanvas= */ null);
@ -52,6 +54,12 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
expect(stream).toEqual('confidence_masks'); expect(stream).toEqual('confidence_masks');
this.confidenceMasksListener = listener; this.confidenceMasksListener = listener;
}); });
this.attachListenerSpies[2] =
spyOn(this.graphRunner, 'attachFloatVectorListener')
.and.callFake((stream, listener) => {
expect(stream).toEqual('quality_scores');
this.qualityScoresListener = listener;
});
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
}); });
@ -266,6 +274,7 @@ describe('ImageSegmenter', () => {
it('invokes listener after masks are available', async () => { it('invokes listener after masks are available', async () => {
const categoryMask = new Uint8Array([1]); const categoryMask = new Uint8Array([1]);
const confidenceMask = new Float32Array([0.0]); const confidenceMask = new Float32Array([0.0]);
const qualityScores = [1.0];
let listenerCalled = false; let listenerCalled = false;
await imageSegmenter.setOptions( await imageSegmenter.setOptions(
@ -283,11 +292,16 @@ describe('ImageSegmenter', () => {
], ],
1337); 1337);
expect(listenerCalled).toBeFalse(); expect(listenerCalled).toBeFalse();
imageSegmenter.qualityScoresListener!(qualityScores, 1337);
expect(listenerCalled).toBeFalse();
}); });
return new Promise<void>(resolve => { return new Promise<void>(resolve => {
imageSegmenter.segment({} as HTMLImageElement, () => { imageSegmenter.segment({} as HTMLImageElement, result => {
listenerCalled = true; listenerCalled = true;
expect(result.categoryMask).toBeInstanceOf(MPMask);
expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
expect(result.qualityScores).toEqual(qualityScores);
resolve(); resolve();
}); });
}); });

View File

@ -42,6 +42,7 @@ const NORM_RECT_IN_STREAM = 'norm_rect_in';
const ROI_IN_STREAM = 'roi_in'; const ROI_IN_STREAM = 'roi_in';
const CONFIDENCE_MASKS_STREAM = 'confidence_masks'; const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
const CATEGORY_MASK_STREAM = 'category_mask'; const CATEGORY_MASK_STREAM = 'category_mask';
const QUALITY_SCORES_STREAM = 'quality_scores';
const IMAGEA_SEGMENTER_GRAPH = const IMAGEA_SEGMENTER_GRAPH =
'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph'; 'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph';
const DEFAULT_OUTPUT_CATEGORY_MASK = false; const DEFAULT_OUTPUT_CATEGORY_MASK = false;
@ -86,6 +87,7 @@ export type InteractiveSegmenterCallback =
export class InteractiveSegmenter extends VisionTaskRunner { export class InteractiveSegmenter extends VisionTaskRunner {
private categoryMask?: MPMask; private categoryMask?: MPMask;
private confidenceMasks?: MPMask[]; private confidenceMasks?: MPMask[];
private qualityScores?: number[];
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK; private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS; private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
private userCallback?: InteractiveSegmenterCallback; private userCallback?: InteractiveSegmenterCallback;
@ -284,12 +286,13 @@ export class InteractiveSegmenter extends VisionTaskRunner {
private reset(): void { private reset(): void {
this.confidenceMasks = undefined; this.confidenceMasks = undefined;
this.categoryMask = undefined; this.categoryMask = undefined;
this.qualityScores = undefined;
} }
private processResults(): InteractiveSegmenterResult|void { private processResults(): InteractiveSegmenterResult|void {
try { try {
const result = new InteractiveSegmenterResult( const result = new InteractiveSegmenterResult(
this.confidenceMasks, this.categoryMask); this.confidenceMasks, this.categoryMask, this.qualityScores);
if (this.userCallback) { if (this.userCallback) {
this.userCallback(result); this.userCallback(result);
} else { } else {
@ -361,6 +364,20 @@ export class InteractiveSegmenter extends VisionTaskRunner {
}); });
} }
// Expose the segmenter node's QUALITY_SCORES output as a graph output stream.
graphConfig.addOutputStream(QUALITY_SCORES_STREAM);
segmenterNode.addOutputStream('QUALITY_SCORES:' + QUALITY_SCORES_STREAM);

// Cache the scores so they can be passed into the result object, and record
// the timestamp of the packet that delivered them.
this.graphRunner.attachFloatVectorListener(
    QUALITY_SCORES_STREAM, (scores, timestamp) => {
      this.qualityScores = scores;
      this.setLatestOutputTimestamp(timestamp);
    });
// On an empty packet for this stream, clear only the quality scores. The
// category mask belongs to a different stream and must not be discarded here.
this.graphRunner.attachEmptyPacketListener(
    QUALITY_SCORES_STREAM, timestamp => {
      this.qualityScores = undefined;
      this.setLatestOutputTimestamp(timestamp);
    });
const binaryGraph = graphConfig.serializeBinary(); const binaryGraph = graphConfig.serializeBinary();
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
} }

View File

@ -30,7 +30,13 @@ export class InteractiveSegmenterResult {
* `WebGLTexture`-backed `MPImage` where each pixel represents the class * `WebGLTexture`-backed `MPImage` where each pixel represents the class
* which the pixel in the original image was predicted to belong to. * which the pixel in the original image was predicted to belong to.
*/ */
readonly categoryMask?: MPMask) {} readonly categoryMask?: MPMask,
/**
* The quality scores of the result masks, in the range of [0, 1].
* Defaults to `1` if the model doesn't output quality scores. Each
* element corresponds to the score of the category in the model outputs.
*/
readonly qualityScores?: number[]) {}
/** Frees the resources held by the category and confidence masks. */ /** Frees the resources held by the category and confidence masks. */
close(): void { close(): void {

View File

@ -46,6 +46,8 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
((images: WasmImage, timestamp: number) => void)|undefined; ((images: WasmImage, timestamp: number) => void)|undefined;
confidenceMasksListener: confidenceMasksListener:
((images: WasmImage[], timestamp: number) => void)|undefined; ((images: WasmImage[], timestamp: number) => void)|undefined;
qualityScoresListener:
((data: number[], timestamp: number) => void)|undefined;
lastRoi?: RenderDataProto; lastRoi?: RenderDataProto;
constructor() { constructor() {
@ -64,6 +66,12 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
expect(stream).toEqual('confidence_masks'); expect(stream).toEqual('confidence_masks');
this.confidenceMasksListener = listener; this.confidenceMasksListener = listener;
}); });
this.attachListenerSpies[2] =
spyOn(this.graphRunner, 'attachFloatVectorListener')
.and.callFake((stream, listener) => {
expect(stream).toEqual('quality_scores');
this.qualityScoresListener = listener;
});
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
}); });
@ -277,9 +285,10 @@ describe('InteractiveSegmenter', () => {
}); });
}); });
it('invokes listener after masks are avaiblae', async () => { it('invokes listener after masks are available', async () => {
const categoryMask = new Uint8Array([1]); const categoryMask = new Uint8Array([1]);
const confidenceMask = new Float32Array([0.0]); const confidenceMask = new Float32Array([0.0]);
const qualityScores = [1.0];
let listenerCalled = false; let listenerCalled = false;
await interactiveSegmenter.setOptions( await interactiveSegmenter.setOptions(
@ -297,11 +306,16 @@ describe('InteractiveSegmenter', () => {
], ],
1337); 1337);
expect(listenerCalled).toBeFalse(); expect(listenerCalled).toBeFalse();
interactiveSegmenter.qualityScoresListener!(qualityScores, 1337);
expect(listenerCalled).toBeFalse();
}); });
return new Promise<void>(resolve => { return new Promise<void>(resolve => {
interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, () => { interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, result => {
listenerCalled = true; listenerCalled = true;
expect(result.categoryMask).toBeInstanceOf(MPMask);
expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
expect(result.qualityScores).toEqual(qualityScores);
resolve(); resolve();
}); });
}); });

View File

@ -210,7 +210,8 @@ describe('ObjectDetector', () => {
categoryName: '', categoryName: '',
displayName: '', displayName: '',
}], }],
boundingBox: {originX: 0, originY: 0, width: 0, height: 0} boundingBox: {originX: 0, originY: 0, width: 0, height: 0},
keypoints: []
}); });
}); });
}); });