Merge branch 'master' into ios-code-review-fixes

This commit is contained in:
Prianka Liz Kariat 2023-05-24 19:57:46 +05:30
commit e2e90dcac6
27 changed files with 155 additions and 75 deletions

View File

@ -68,11 +68,11 @@ StatusBuilder&& StatusBuilder::SetNoLogging() && {
return std::move(SetNoLogging());
}
StatusBuilder::operator Status() const& {
StatusBuilder::operator absl::Status() const& {
return StatusBuilder(*this).JoinMessageToStatus();
}
StatusBuilder::operator Status() && { return JoinMessageToStatus(); }
StatusBuilder::operator absl::Status() && { return JoinMessageToStatus(); }
absl::Status StatusBuilder::JoinMessageToStatus() {
if (!impl_) {

View File

@ -83,8 +83,8 @@ class ABSL_MUST_USE_RESULT StatusBuilder {
return std::move(*this << msg);
}
operator Status() const&;
operator Status() &&;
operator absl::Status() const&;
operator absl::Status() &&;
absl::Status JoinMessageToStatus();

View File

@ -403,11 +403,11 @@ std::ostream &operator<<(std::ostream &os,
lhs op## = rhs; \
return lhs; \
}
STRONG_INT_VS_STRONG_INT_BINARY_OP(+);
STRONG_INT_VS_STRONG_INT_BINARY_OP(-);
STRONG_INT_VS_STRONG_INT_BINARY_OP(&);
STRONG_INT_VS_STRONG_INT_BINARY_OP(|);
STRONG_INT_VS_STRONG_INT_BINARY_OP(^);
STRONG_INT_VS_STRONG_INT_BINARY_OP(+)
STRONG_INT_VS_STRONG_INT_BINARY_OP(-)
STRONG_INT_VS_STRONG_INT_BINARY_OP(&)
STRONG_INT_VS_STRONG_INT_BINARY_OP(|)
STRONG_INT_VS_STRONG_INT_BINARY_OP(^)
#undef STRONG_INT_VS_STRONG_INT_BINARY_OP
// Define operators that take one StrongInt and one native integer argument.
@ -431,12 +431,12 @@ STRONG_INT_VS_STRONG_INT_BINARY_OP(^);
rhs op## = lhs; \
return rhs; \
}
STRONG_INT_VS_NUMERIC_BINARY_OP(*);
NUMERIC_VS_STRONG_INT_BINARY_OP(*);
STRONG_INT_VS_NUMERIC_BINARY_OP(/);
STRONG_INT_VS_NUMERIC_BINARY_OP(%);
STRONG_INT_VS_NUMERIC_BINARY_OP(<<); // NOLINT(whitespace/operators)
STRONG_INT_VS_NUMERIC_BINARY_OP(>>); // NOLINT(whitespace/operators)
STRONG_INT_VS_NUMERIC_BINARY_OP(*)
NUMERIC_VS_STRONG_INT_BINARY_OP(*)
STRONG_INT_VS_NUMERIC_BINARY_OP(/)
STRONG_INT_VS_NUMERIC_BINARY_OP(%)
STRONG_INT_VS_NUMERIC_BINARY_OP(<<) // NOLINT(whitespace/operators)
STRONG_INT_VS_NUMERIC_BINARY_OP(>>) // NOLINT(whitespace/operators)
#undef STRONG_INT_VS_NUMERIC_BINARY_OP
#undef NUMERIC_VS_STRONG_INT_BINARY_OP
@ -447,12 +447,12 @@ STRONG_INT_VS_NUMERIC_BINARY_OP(>>); // NOLINT(whitespace/operators)
StrongInt<TagType, ValueType, ValidatorType> rhs) { \
return lhs.value() op rhs.value(); \
}
STRONG_INT_COMPARISON_OP(==); // NOLINT(whitespace/operators)
STRONG_INT_COMPARISON_OP(!=); // NOLINT(whitespace/operators)
STRONG_INT_COMPARISON_OP(<); // NOLINT(whitespace/operators)
STRONG_INT_COMPARISON_OP(<=); // NOLINT(whitespace/operators)
STRONG_INT_COMPARISON_OP(>); // NOLINT(whitespace/operators)
STRONG_INT_COMPARISON_OP(>=); // NOLINT(whitespace/operators)
STRONG_INT_COMPARISON_OP(==) // NOLINT(whitespace/operators)
STRONG_INT_COMPARISON_OP(!=) // NOLINT(whitespace/operators)
STRONG_INT_COMPARISON_OP(<) // NOLINT(whitespace/operators)
STRONG_INT_COMPARISON_OP(<=) // NOLINT(whitespace/operators)
STRONG_INT_COMPARISON_OP(>) // NOLINT(whitespace/operators)
STRONG_INT_COMPARISON_OP(>=) // NOLINT(whitespace/operators)
#undef STRONG_INT_COMPARISON_OP
} // namespace intops

View File

@ -57,7 +57,7 @@ namespace mediapipe {
// have underflow/overflow etc. This type is used internally by Timestamp
// and TimestampDiff.
MEDIAPIPE_DEFINE_SAFE_INT_TYPE(TimestampBaseType, int64,
mediapipe::intops::LogFatalOnError);
mediapipe::intops::LogFatalOnError)
class TimestampDiff;

View File

@ -272,17 +272,20 @@ DEFINE_MEDIAPIPE_TYPE_MAP(PacketTypeStringToMediaPipeTypeData, std::string);
#define MEDIAPIPE_REGISTER_TYPE(type, type_name, serialize_fn, deserialize_fn) \
SET_MEDIAPIPE_TYPE_MAP_VALUE( \
mediapipe::PacketTypeIdToMediaPipeTypeData, \
mediapipe::tool::GetTypeHash< \
mediapipe::type_map_internal::ReflectType<void(type*)>::Type>(), \
mediapipe::TypeId::Of< \
mediapipe::type_map_internal::ReflectType<void(type*)>::Type>() \
.hash_code(), \
(mediapipe::MediaPipeTypeData{ \
mediapipe::tool::GetTypeHash< \
mediapipe::type_map_internal::ReflectType<void(type*)>::Type>(), \
mediapipe::TypeId::Of< \
mediapipe::type_map_internal::ReflectType<void(type*)>::Type>() \
.hash_code(), \
type_name, serialize_fn, deserialize_fn})); \
SET_MEDIAPIPE_TYPE_MAP_VALUE( \
mediapipe::PacketTypeStringToMediaPipeTypeData, type_name, \
(mediapipe::MediaPipeTypeData{ \
mediapipe::tool::GetTypeHash< \
mediapipe::type_map_internal::ReflectType<void(type*)>::Type>(), \
mediapipe::TypeId::Of< \
mediapipe::type_map_internal::ReflectType<void(type*)>::Type>() \
.hash_code(), \
type_name, serialize_fn, deserialize_fn}));
// End define MEDIAPIPE_REGISTER_TYPE.

View File

@ -187,7 +187,7 @@ class PerceptualLoss(tf.keras.Model, metaclass=abc.ABCMeta):
"""Instantiates perceptual loss.
Args:
feature_weight: The weight coeffcients of multiple model extracted
feature_weight: The weight coefficients of multiple model extracted
features used for calculating the perceptual loss.
loss_weight: The weight coefficients between `style_loss` and
`content_loss`.

View File

@ -105,7 +105,7 @@ class FaceStylizer(object):
self._train_model(train_data=train_data, preprocessor=self._preprocessor)
def _create_model(self):
"""Creates the componenets of face stylizer."""
"""Creates the components of face stylizer."""
self._encoder = model_util.load_keras_model(
constants.FACE_STYLIZER_ENCODER_MODEL_FILES.get_path()
)
@ -138,7 +138,7 @@ class FaceStylizer(object):
"""
train_dataset = train_data.gen_tf_dataset(preprocess=preprocessor)
# TODO: Support processing mulitple input style images. The
# TODO: Support processing multiple input style images. The
# input style images are expected to have similar style.
# style_sample represents a tuple of (style_image, style_label).
style_sample = next(iter(train_dataset))

View File

@ -33,7 +33,7 @@ struct ImageSegmenterResult {
// A category mask of uint8 image in GRAY8 format where each pixel represents
// the class which the pixel in the original image was predicted to belong to.
std::optional<Image> category_mask;
// The quality scores of the result masks, in the range of [0, 1]. Default to
// The quality scores of the result masks, in the range of [0, 1]. Defaults to
// `1` if the model doesn't output quality scores. Each element corresponds to
// the score of the category in the model outputs.
std::vector<float> quality_scores;

View File

@ -28,7 +28,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
XCTAssertNotNil(error); \
XCTAssertEqualObjects(error.domain, expectedError.domain); \
XCTAssertEqual(error.code, expectedError.code); \
XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
#define AssertEqualCategoryArrays(categories, expectedCategories) \
XCTAssertEqual(categories.count, expectedCategories.count); \

View File

@ -29,7 +29,7 @@ static const float kSimilarityDiffTolerance = 1e-4;
XCTAssertNotNil(error); \
XCTAssertEqualObjects(error.domain, expectedError.domain); \
XCTAssertEqual(error.code, expectedError.code); \
XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription) \
XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
#define AssertTextEmbedderResultHasOneEmbedding(textEmbedderResult) \
XCTAssertNotNil(textEmbedderResult); \

View File

@ -34,8 +34,8 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
XCTAssertNotNil(error); \
XCTAssertEqualObjects(error.domain, expectedError.domain); \
XCTAssertEqual(error.code, expectedError.code); \
XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription) \
XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
#define AssertEqualCategoryArrays(categories, expectedCategories) \
XCTAssertEqual(categories.count, expectedCategories.count); \
for (int i = 0; i < categories.count; i++) { \
@ -668,10 +668,10 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
// Because of flow limiting, we cannot ensure that the callback will be
// invoked `iterationCount` times.
// An normal expectation will fail if expectation.fullfill() is not called
// A normal expectation will fail if expectation.fulfill() is not called
// `expectation.expectedFulfillmentCount` times.
// If `expectation.isInverted = true`, the test will only succeed if
// expectation is not fullfilled for the specified `expectedFulfillmentCount`.
// expectation is not fulfilled for the specified `expectedFulfillmentCount`.
// Since in our case we cannot predict how many times the expectation is
// supposed to be fulfilled setting,
// `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and

View File

@ -32,7 +32,7 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
XCTAssertNotNil(error); \
XCTAssertEqualObjects(error.domain, expectedError.domain); \
XCTAssertEqual(error.code, expectedError.code); \
XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
#define AssertEqualCategories(category, expectedCategory, detectionIndex, categoryIndex) \
XCTAssertEqual(category.index, expectedCategory.index, \
@ -673,15 +673,15 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
// Because of flow limiting, we cannot ensure that the callback will be
// invoked `iterationCount` times.
// An normal expectation will fail if expectation.fullfill() is not called
// A normal expectation will fail if expectation.fulfill() is not called
// `expectation.expectedFulfillmentCount` times.
// If `expectation.isInverted = true`, the test will only succeed if
// expectation is not fullfilled for the specified `expectedFulfillmentCount`.
// expectation is not fulfilled for the specified `expectedFulfillmentCount`.
// Since in our case we cannot predict how many times the expectation is
// supposed to be fullfilled setting,
// supposed to be fulfilled setting,
// `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
// `expectation.isInverted = true` ensures that test succeeds if
// expectation is fullfilled <= `iterationCount` times.
// expectation is fulfilled <= `iterationCount` times.
XCTestExpectation *expectation = [[XCTestExpectation alloc]
initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"];
expectation.expectedFulfillmentCount = iterationCount + 1;

View File

@ -17,7 +17,7 @@
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
static const int kMicroSecondsPerMilliSecond = 1000;
static const int kMicrosecondsPerMillisecond = 1000;
namespace {
using ClassificationResultProto =
@ -29,9 +29,9 @@ using ::mediapipe::Packet;
+ (nullable MPPImageClassifierResult *)imageClassifierResultWithClassificationsPacket:
(const Packet &)packet {
// Even if packet does not validate a the expected type, you can safely access the timestamp.
// Even if packet does not validate as the expected type, you can safely access the timestamp.
NSInteger timestampInMilliSeconds =
(NSInteger)(packet.Timestamp().Value() / kMicroSecondsPerMilliSecond);
(NSInteger)(packet.Timestamp().Value() / kMicrosecondsPerMillisecond);
if (!packet.ValidateAsType<ClassificationResultProto>().ok()) {
// MPPClassificationResult's timestamp is populated from timestamp `ClassificationResultProto`'s
@ -48,7 +48,7 @@ using ::mediapipe::Packet;
return [[MPPImageClassifierResult alloc]
initWithClassificationResult:classificationResult
timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond)];
kMicrosecondsPerMillisecond)];
}
@end

View File

@ -166,7 +166,7 @@ public class BaseVisionTaskApi implements AutoCloseable {
// For 90° and 270° rotations, we need to swap width and height.
// This is due to the internal behavior of ImageToTensorCalculator, which:
// - first denormalizes the provided rect by multiplying the rect width or
// height by the image width or height, repectively.
// height by the image width or height, respectively.
// - then rotates this denormalized rect by the provided rotation, and
// uses this for cropping,
// - then finally rotates this back.

View File

@ -34,8 +34,8 @@ public abstract class ImageSegmenterResult implements TaskResult {
* @param categoryMask an {@link Optional} MPImage in IMAGE_FORMAT_ALPHA format representing a
* category mask, where each pixel represents the class which the pixel in the original image
* was predicted to belong to.
* @param qualityScores The quality scores of the result masks, in the range of [0, 1]. Default to
* `1` if the model doesn't output quality scores. Each element corresponds to the score of
* @param qualityScores The quality scores of the result masks, in the range of [0, 1]. Defaults
* to `1` if the model doesn't output quality scores. Each element corresponds to the score of
* the category in the model outputs.
* @param timestampMs a timestamp for this result.
*/

View File

@ -27,13 +27,14 @@ export declare interface Detection {
boundingBox?: BoundingBox;
/**
* Optional list of keypoints associated with the detection. Keypoints
* represent interesting points related to the detection. For example, the
* keypoints represent the eye, ear and mouth from face detection model. Or
* in the template matching detection, e.g. KNIFT, they can represent the
* feature points for template matching.
* List of keypoints associated with the detection. Keypoints represent
* interesting points related to the detection. For example, the keypoints
* represent the eye, ear and mouth from face detection model. Or in the
* template matching detection, e.g. KNIFT, they can represent the feature
* points for template matching. Contains an empty list if no keypoints are
* detected.
*/
keypoints?: NormalizedKeypoint[];
keypoints: NormalizedKeypoint[];
}
/** Detection results of a model. */

View File

@ -85,7 +85,8 @@ describe('convertFromDetectionProto()', () => {
categoryName: '',
displayName: '',
}],
boundingBox: {originX: 0, originY: 0, width: 0, height: 0}
boundingBox: {originX: 0, originY: 0, width: 0, height: 0},
keypoints: []
});
});
});

View File

@ -26,7 +26,7 @@ export function convertFromDetectionProto(source: DetectionProto): Detection {
const labels = source.getLabelList();
const displayNames = source.getDisplayNameList();
const detection: Detection = {categories: []};
const detection: Detection = {categories: [], keypoints: []};
for (let i = 0; i < scores.length; i++) {
detection.categories.push({
score: scores[i],
@ -47,7 +47,6 @@ export function convertFromDetectionProto(source: DetectionProto): Detection {
}
if (source.getLocationData()?.getRelativeKeypointsList().length) {
detection.keypoints = [];
for (const keypoint of
source.getLocationData()!.getRelativeKeypointsList()) {
detection.keypoints.push({

View File

@ -191,7 +191,8 @@ describe('FaceDetector', () => {
categoryName: '',
displayName: '',
}],
boundingBox: {originX: 0, originY: 0, width: 0, height: 0}
boundingBox: {originX: 0, originY: 0, width: 0, height: 0},
keypoints: []
});
});
});

View File

@ -171,7 +171,7 @@ export class FaceStylizer extends VisionTaskRunner {
/**
* Performs face stylization on the provided single image and returns the
* result. This method creates a copy of the resulting image and should not be
* used in high-throughput applictions. Only use this method when the
* used in high-throughput applications. Only use this method when the
* FaceStylizer is created with the image running mode.
*
* @param image An image to process.
@ -182,7 +182,7 @@ export class FaceStylizer extends VisionTaskRunner {
/**
* Performs face stylization on the provided single image and returns the
* result. This method creates a copy of the resulting image and should not be
* used in high-throughput applictions. Only use this method when the
* used in high-throughput applications. Only use this method when the
* FaceStylizer is created with the image running mode.
*
* The 'imageProcessingOptions' parameter can be used to specify one or all
@ -275,7 +275,7 @@ export class FaceStylizer extends VisionTaskRunner {
/**
* Performs face stylization on the provided video frame. This method creates
* a copy of the resulting image and should not be used in high-throughput
* applictions. Only use this method when the FaceStylizer is created with the
* applications. Only use this method when the FaceStylizer is created with the
* video running mode.
*
* The input frame can be of any size. It's required to provide the video

View File

@ -39,6 +39,7 @@ const IMAGE_STREAM = 'image_in';
const NORM_RECT_STREAM = 'norm_rect';
const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
const CATEGORY_MASK_STREAM = 'category_mask';
const QUALITY_SCORES_STREAM = 'quality_scores';
const IMAGE_SEGMENTER_GRAPH =
'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
@ -61,6 +62,7 @@ export type ImageSegmenterCallback = (result: ImageSegmenterResult) => void;
export class ImageSegmenter extends VisionTaskRunner {
private categoryMask?: MPMask;
private confidenceMasks?: MPMask[];
private qualityScores?: number[];
private labels: string[] = [];
private userCallback?: ImageSegmenterCallback;
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
@ -229,7 +231,7 @@ export class ImageSegmenter extends VisionTaskRunner {
/**
* Performs image segmentation on the provided single image and returns the
* segmentation result. This method creates a copy of the resulting masks and
* should not be used in high-throughput applictions. Only use this method
* should not be used in high-throughput applications. Only use this method
* when the ImageSegmenter is created with running mode `image`.
*
* @param image An image to process.
@ -240,7 +242,7 @@ export class ImageSegmenter extends VisionTaskRunner {
/**
* Performs image segmentation on the provided single image and returns the
* segmentation result. This method creates a copy of the resulting masks and
* should not be used in high-v applictions. Only use this method when
* should not be used in high-throughput applications. Only use this method when
* the ImageSegmenter is created with running mode `image`.
*
* @param image An image to process.
@ -318,7 +320,7 @@ export class ImageSegmenter extends VisionTaskRunner {
/**
* Performs image segmentation on the provided video frame and returns the
* segmentation result. This method creates a copy of the resulting masks and
* should not be used in high-v applictions. Only use this method when
* should not be used in high-throughput applications. Only use this method when
* the ImageSegmenter is created with running mode `video`.
*
* @param videoFrame A video frame to process.
@ -367,12 +369,13 @@ export class ImageSegmenter extends VisionTaskRunner {
/**
 * Clears the cached category mask, confidence masks and quality scores by
 * setting each of them to `undefined`.
 */
private reset(): void {
this.categoryMask = undefined;
this.confidenceMasks = undefined;
this.qualityScores = undefined;
}
private processResults(): ImageSegmenterResult|void {
try {
const result =
new ImageSegmenterResult(this.confidenceMasks, this.categoryMask);
const result = new ImageSegmenterResult(
this.confidenceMasks, this.categoryMask, this.qualityScores);
if (this.userCallback) {
this.userCallback(result);
} else {
@ -442,6 +445,20 @@ export class ImageSegmenter extends VisionTaskRunner {
});
}
// Expose the segmenter node's QUALITY_SCORES output as a graph output stream
// and cache the values it produces.
graphConfig.addOutputStream(QUALITY_SCORES_STREAM);
segmenterNode.addOutputStream('QUALITY_SCORES:' + QUALITY_SCORES_STREAM);
this.graphRunner.attachFloatVectorListener(
    QUALITY_SCORES_STREAM, (scores, timestamp) => {
      this.qualityScores = scores;
      this.setLatestOutputTimestamp(timestamp);
    });
this.graphRunner.attachEmptyPacketListener(
    QUALITY_SCORES_STREAM, timestamp => {
      // An empty packet on the quality-scores stream must clear the cached
      // quality scores. The previous code cleared `categoryMask` here, which
      // belongs to a different stream and left stale scores in place.
      this.qualityScores = undefined;
      this.setLatestOutputTimestamp(timestamp);
    });
const binaryGraph = graphConfig.serializeBinary();
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
}

View File

@ -30,7 +30,13 @@ export class ImageSegmenterResult {
* `WebGLTexture`-backed `MPImage` where each pixel represents the class
* which the pixel in the original image was predicted to belong to.
*/
readonly categoryMask?: MPMask) {}
readonly categoryMask?: MPMask,
/**
* The quality scores of the result masks, in the range of [0, 1].
* Defaults to `1` if the model doesn't output quality scores. Each
* element corresponds to the score of the category in the model outputs.
*/
readonly qualityScores?: number[]) {}
/** Frees the resources held by the category and confidence masks. */
close(): void {

View File

@ -35,6 +35,8 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
((images: WasmImage, timestamp: number) => void)|undefined;
confidenceMasksListener:
((images: WasmImage[], timestamp: number) => void)|undefined;
qualityScoresListener:
((data: number[], timestamp: number) => void)|undefined;
constructor() {
super(createSpyWasmModule(), /* glCanvas= */ null);
@ -52,6 +54,12 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
expect(stream).toEqual('confidence_masks');
this.confidenceMasksListener = listener;
});
this.attachListenerSpies[2] =
spyOn(this.graphRunner, 'attachFloatVectorListener')
.and.callFake((stream, listener) => {
expect(stream).toEqual('quality_scores');
this.qualityScoresListener = listener;
});
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
});
@ -266,6 +274,7 @@ describe('ImageSegmenter', () => {
it('invokes listener after masks are available', async () => {
const categoryMask = new Uint8Array([1]);
const confidenceMask = new Float32Array([0.0]);
const qualityScores = [1.0];
let listenerCalled = false;
await imageSegmenter.setOptions(
@ -283,11 +292,16 @@ describe('ImageSegmenter', () => {
],
1337);
expect(listenerCalled).toBeFalse();
imageSegmenter.qualityScoresListener!(qualityScores, 1337);
expect(listenerCalled).toBeFalse();
});
return new Promise<void>(resolve => {
imageSegmenter.segment({} as HTMLImageElement, () => {
imageSegmenter.segment({} as HTMLImageElement, result => {
listenerCalled = true;
expect(result.categoryMask).toBeInstanceOf(MPMask);
expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
expect(result.qualityScores).toEqual(qualityScores);
resolve();
});
});

View File

@ -42,6 +42,7 @@ const NORM_RECT_IN_STREAM = 'norm_rect_in';
const ROI_IN_STREAM = 'roi_in';
const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
const CATEGORY_MASK_STREAM = 'category_mask';
const QUALITY_SCORES_STREAM = 'quality_scores';
const IMAGEA_SEGMENTER_GRAPH =
'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph';
const DEFAULT_OUTPUT_CATEGORY_MASK = false;
@ -86,6 +87,7 @@ export type InteractiveSegmenterCallback =
export class InteractiveSegmenter extends VisionTaskRunner {
private categoryMask?: MPMask;
private confidenceMasks?: MPMask[];
private qualityScores?: number[];
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
private userCallback?: InteractiveSegmenterCallback;
@ -284,12 +286,13 @@ export class InteractiveSegmenter extends VisionTaskRunner {
/**
 * Clears the cached confidence masks, category mask and quality scores by
 * setting each of them to `undefined`.
 */
private reset(): void {
this.confidenceMasks = undefined;
this.categoryMask = undefined;
this.qualityScores = undefined;
}
private processResults(): InteractiveSegmenterResult|void {
try {
const result = new InteractiveSegmenterResult(
this.confidenceMasks, this.categoryMask);
this.confidenceMasks, this.categoryMask, this.qualityScores);
if (this.userCallback) {
this.userCallback(result);
} else {
@ -361,6 +364,20 @@ export class InteractiveSegmenter extends VisionTaskRunner {
});
}
// Expose the segmenter node's QUALITY_SCORES output as a graph output stream
// and cache the values it produces.
graphConfig.addOutputStream(QUALITY_SCORES_STREAM);
segmenterNode.addOutputStream('QUALITY_SCORES:' + QUALITY_SCORES_STREAM);
this.graphRunner.attachFloatVectorListener(
    QUALITY_SCORES_STREAM, (scores, timestamp) => {
      this.qualityScores = scores;
      this.setLatestOutputTimestamp(timestamp);
    });
this.graphRunner.attachEmptyPacketListener(
    QUALITY_SCORES_STREAM, timestamp => {
      // An empty packet on the quality-scores stream must clear the cached
      // quality scores. The previous code cleared `categoryMask` here, which
      // belongs to a different stream and left stale scores in place.
      this.qualityScores = undefined;
      this.setLatestOutputTimestamp(timestamp);
    });
const binaryGraph = graphConfig.serializeBinary();
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
}

View File

@ -30,7 +30,13 @@ export class InteractiveSegmenterResult {
* `WebGLTexture`-backed `MPImage` where each pixel represents the class
* which the pixel in the original image was predicted to belong to.
*/
readonly categoryMask?: MPMask) {}
readonly categoryMask?: MPMask,
/**
* The quality scores of the result masks, in the range of [0, 1].
* Defaults to `1` if the model doesn't output quality scores. Each
* element corresponds to the score of the category in the model outputs.
*/
readonly qualityScores?: number[]) {}
/** Frees the resources held by the category and confidence masks. */
close(): void {

View File

@ -46,6 +46,8 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
((images: WasmImage, timestamp: number) => void)|undefined;
confidenceMasksListener:
((images: WasmImage[], timestamp: number) => void)|undefined;
qualityScoresListener:
((data: number[], timestamp: number) => void)|undefined;
lastRoi?: RenderDataProto;
constructor() {
@ -64,6 +66,12 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
expect(stream).toEqual('confidence_masks');
this.confidenceMasksListener = listener;
});
this.attachListenerSpies[2] =
spyOn(this.graphRunner, 'attachFloatVectorListener')
.and.callFake((stream, listener) => {
expect(stream).toEqual('quality_scores');
this.qualityScoresListener = listener;
});
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
});
@ -277,9 +285,10 @@ describe('InteractiveSegmenter', () => {
});
});
it('invokes listener after masks are avaiblae', async () => {
it('invokes listener after masks are available', async () => {
const categoryMask = new Uint8Array([1]);
const confidenceMask = new Float32Array([0.0]);
const qualityScores = [1.0];
let listenerCalled = false;
await interactiveSegmenter.setOptions(
@ -297,11 +306,16 @@ describe('InteractiveSegmenter', () => {
],
1337);
expect(listenerCalled).toBeFalse();
interactiveSegmenter.qualityScoresListener!(qualityScores, 1337);
expect(listenerCalled).toBeFalse();
});
return new Promise<void>(resolve => {
interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, () => {
interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, result => {
listenerCalled = true;
expect(result.categoryMask).toBeInstanceOf(MPMask);
expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
expect(result.qualityScores).toEqual(qualityScores);
resolve();
});
});

View File

@ -210,7 +210,8 @@ describe('ObjectDetector', () => {
categoryName: '',
displayName: '',
}],
boundingBox: {originX: 0, originY: 0, width: 0, height: 0}
boundingBox: {originX: 0, originY: 0, width: 0, height: 0},
keypoints: []
});
});
});