Merge branch 'master' into ios-code-review-fixes
This commit is contained in:
commit
e2e90dcac6
|
@ -68,11 +68,11 @@ StatusBuilder&& StatusBuilder::SetNoLogging() && {
|
|||
return std::move(SetNoLogging());
|
||||
}
|
||||
|
||||
StatusBuilder::operator Status() const& {
|
||||
StatusBuilder::operator absl::Status() const& {
|
||||
return StatusBuilder(*this).JoinMessageToStatus();
|
||||
}
|
||||
|
||||
StatusBuilder::operator Status() && { return JoinMessageToStatus(); }
|
||||
StatusBuilder::operator absl::Status() && { return JoinMessageToStatus(); }
|
||||
|
||||
absl::Status StatusBuilder::JoinMessageToStatus() {
|
||||
if (!impl_) {
|
||||
|
|
|
@ -83,8 +83,8 @@ class ABSL_MUST_USE_RESULT StatusBuilder {
|
|||
return std::move(*this << msg);
|
||||
}
|
||||
|
||||
operator Status() const&;
|
||||
operator Status() &&;
|
||||
operator absl::Status() const&;
|
||||
operator absl::Status() &&;
|
||||
|
||||
absl::Status JoinMessageToStatus();
|
||||
|
||||
|
|
|
@ -403,11 +403,11 @@ std::ostream &operator<<(std::ostream &os,
|
|||
lhs op## = rhs; \
|
||||
return lhs; \
|
||||
}
|
||||
STRONG_INT_VS_STRONG_INT_BINARY_OP(+);
|
||||
STRONG_INT_VS_STRONG_INT_BINARY_OP(-);
|
||||
STRONG_INT_VS_STRONG_INT_BINARY_OP(&);
|
||||
STRONG_INT_VS_STRONG_INT_BINARY_OP(|);
|
||||
STRONG_INT_VS_STRONG_INT_BINARY_OP(^);
|
||||
STRONG_INT_VS_STRONG_INT_BINARY_OP(+)
|
||||
STRONG_INT_VS_STRONG_INT_BINARY_OP(-)
|
||||
STRONG_INT_VS_STRONG_INT_BINARY_OP(&)
|
||||
STRONG_INT_VS_STRONG_INT_BINARY_OP(|)
|
||||
STRONG_INT_VS_STRONG_INT_BINARY_OP(^)
|
||||
#undef STRONG_INT_VS_STRONG_INT_BINARY_OP
|
||||
|
||||
// Define operators that take one StrongInt and one native integer argument.
|
||||
|
@ -431,12 +431,12 @@ STRONG_INT_VS_STRONG_INT_BINARY_OP(^);
|
|||
rhs op## = lhs; \
|
||||
return rhs; \
|
||||
}
|
||||
STRONG_INT_VS_NUMERIC_BINARY_OP(*);
|
||||
NUMERIC_VS_STRONG_INT_BINARY_OP(*);
|
||||
STRONG_INT_VS_NUMERIC_BINARY_OP(/);
|
||||
STRONG_INT_VS_NUMERIC_BINARY_OP(%);
|
||||
STRONG_INT_VS_NUMERIC_BINARY_OP(<<); // NOLINT(whitespace/operators)
|
||||
STRONG_INT_VS_NUMERIC_BINARY_OP(>>); // NOLINT(whitespace/operators)
|
||||
STRONG_INT_VS_NUMERIC_BINARY_OP(*)
|
||||
NUMERIC_VS_STRONG_INT_BINARY_OP(*)
|
||||
STRONG_INT_VS_NUMERIC_BINARY_OP(/)
|
||||
STRONG_INT_VS_NUMERIC_BINARY_OP(%)
|
||||
STRONG_INT_VS_NUMERIC_BINARY_OP(<<) // NOLINT(whitespace/operators)
|
||||
STRONG_INT_VS_NUMERIC_BINARY_OP(>>) // NOLINT(whitespace/operators)
|
||||
#undef STRONG_INT_VS_NUMERIC_BINARY_OP
|
||||
#undef NUMERIC_VS_STRONG_INT_BINARY_OP
|
||||
|
||||
|
@ -447,12 +447,12 @@ STRONG_INT_VS_NUMERIC_BINARY_OP(>>); // NOLINT(whitespace/operators)
|
|||
StrongInt<TagType, ValueType, ValidatorType> rhs) { \
|
||||
return lhs.value() op rhs.value(); \
|
||||
}
|
||||
STRONG_INT_COMPARISON_OP(==); // NOLINT(whitespace/operators)
|
||||
STRONG_INT_COMPARISON_OP(!=); // NOLINT(whitespace/operators)
|
||||
STRONG_INT_COMPARISON_OP(<); // NOLINT(whitespace/operators)
|
||||
STRONG_INT_COMPARISON_OP(<=); // NOLINT(whitespace/operators)
|
||||
STRONG_INT_COMPARISON_OP(>); // NOLINT(whitespace/operators)
|
||||
STRONG_INT_COMPARISON_OP(>=); // NOLINT(whitespace/operators)
|
||||
STRONG_INT_COMPARISON_OP(==) // NOLINT(whitespace/operators)
|
||||
STRONG_INT_COMPARISON_OP(!=) // NOLINT(whitespace/operators)
|
||||
STRONG_INT_COMPARISON_OP(<) // NOLINT(whitespace/operators)
|
||||
STRONG_INT_COMPARISON_OP(<=) // NOLINT(whitespace/operators)
|
||||
STRONG_INT_COMPARISON_OP(>) // NOLINT(whitespace/operators)
|
||||
STRONG_INT_COMPARISON_OP(>=) // NOLINT(whitespace/operators)
|
||||
#undef STRONG_INT_COMPARISON_OP
|
||||
|
||||
} // namespace intops
|
||||
|
|
|
@ -57,7 +57,7 @@ namespace mediapipe {
|
|||
// have underflow/overflow etc. This type is used internally by Timestamp
|
||||
// and TimestampDiff.
|
||||
MEDIAPIPE_DEFINE_SAFE_INT_TYPE(TimestampBaseType, int64,
|
||||
mediapipe::intops::LogFatalOnError);
|
||||
mediapipe::intops::LogFatalOnError)
|
||||
|
||||
class TimestampDiff;
|
||||
|
||||
|
|
|
@ -272,17 +272,20 @@ DEFINE_MEDIAPIPE_TYPE_MAP(PacketTypeStringToMediaPipeTypeData, std::string);
|
|||
#define MEDIAPIPE_REGISTER_TYPE(type, type_name, serialize_fn, deserialize_fn) \
|
||||
SET_MEDIAPIPE_TYPE_MAP_VALUE( \
|
||||
mediapipe::PacketTypeIdToMediaPipeTypeData, \
|
||||
mediapipe::tool::GetTypeHash< \
|
||||
mediapipe::type_map_internal::ReflectType<void(type*)>::Type>(), \
|
||||
mediapipe::TypeId::Of< \
|
||||
mediapipe::type_map_internal::ReflectType<void(type*)>::Type>() \
|
||||
.hash_code(), \
|
||||
(mediapipe::MediaPipeTypeData{ \
|
||||
mediapipe::tool::GetTypeHash< \
|
||||
mediapipe::type_map_internal::ReflectType<void(type*)>::Type>(), \
|
||||
mediapipe::TypeId::Of< \
|
||||
mediapipe::type_map_internal::ReflectType<void(type*)>::Type>() \
|
||||
.hash_code(), \
|
||||
type_name, serialize_fn, deserialize_fn})); \
|
||||
SET_MEDIAPIPE_TYPE_MAP_VALUE( \
|
||||
mediapipe::PacketTypeStringToMediaPipeTypeData, type_name, \
|
||||
(mediapipe::MediaPipeTypeData{ \
|
||||
mediapipe::tool::GetTypeHash< \
|
||||
mediapipe::type_map_internal::ReflectType<void(type*)>::Type>(), \
|
||||
mediapipe::TypeId::Of< \
|
||||
mediapipe::type_map_internal::ReflectType<void(type*)>::Type>() \
|
||||
.hash_code(), \
|
||||
type_name, serialize_fn, deserialize_fn}));
|
||||
// End define MEDIAPIPE_REGISTER_TYPE.
|
||||
|
||||
|
|
|
@ -187,7 +187,7 @@ class PerceptualLoss(tf.keras.Model, metaclass=abc.ABCMeta):
|
|||
"""Instantiates perceptual loss.
|
||||
|
||||
Args:
|
||||
feature_weight: The weight coeffcients of multiple model extracted
|
||||
feature_weight: The weight coefficients of multiple model extracted
|
||||
features used for calculating the perceptual loss.
|
||||
loss_weight: The weight coefficients between `style_loss` and
|
||||
`content_loss`.
|
||||
|
|
|
@ -105,7 +105,7 @@ class FaceStylizer(object):
|
|||
self._train_model(train_data=train_data, preprocessor=self._preprocessor)
|
||||
|
||||
def _create_model(self):
|
||||
"""Creates the componenets of face stylizer."""
|
||||
"""Creates the components of face stylizer."""
|
||||
self._encoder = model_util.load_keras_model(
|
||||
constants.FACE_STYLIZER_ENCODER_MODEL_FILES.get_path()
|
||||
)
|
||||
|
@ -138,7 +138,7 @@ class FaceStylizer(object):
|
|||
"""
|
||||
train_dataset = train_data.gen_tf_dataset(preprocess=preprocessor)
|
||||
|
||||
# TODO: Support processing mulitple input style images. The
|
||||
# TODO: Support processing multiple input style images. The
|
||||
# input style images are expected to have similar style.
|
||||
# style_sample represents a tuple of (style_image, style_label).
|
||||
style_sample = next(iter(train_dataset))
|
||||
|
|
|
@ -33,7 +33,7 @@ struct ImageSegmenterResult {
|
|||
// A category mask of uint8 image in GRAY8 format where each pixel represents
|
||||
// the class which the pixel in the original image was predicted to belong to.
|
||||
std::optional<Image> category_mask;
|
||||
// The quality scores of the result masks, in the range of [0, 1]. Default to
|
||||
// The quality scores of the result masks, in the range of [0, 1]. Defaults to
|
||||
// `1` if the model doesn't output quality scores. Each element corresponds to
|
||||
// the score of the category in the model outputs.
|
||||
std::vector<float> quality_scores;
|
||||
|
|
|
@ -29,7 +29,7 @@ static const float kSimilarityDiffTolerance = 1e-4;
|
|||
XCTAssertNotNil(error); \
|
||||
XCTAssertEqualObjects(error.domain, expectedError.domain); \
|
||||
XCTAssertEqual(error.code, expectedError.code); \
|
||||
XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription) \
|
||||
XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
|
||||
|
||||
#define AssertTextEmbedderResultHasOneEmbedding(textEmbedderResult) \
|
||||
XCTAssertNotNil(textEmbedderResult); \
|
||||
|
|
|
@ -34,7 +34,7 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
|||
XCTAssertNotNil(error); \
|
||||
XCTAssertEqualObjects(error.domain, expectedError.domain); \
|
||||
XCTAssertEqual(error.code, expectedError.code); \
|
||||
XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription) \
|
||||
XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
|
||||
|
||||
#define AssertEqualCategoryArrays(categories, expectedCategories) \
|
||||
XCTAssertEqual(categories.count, expectedCategories.count); \
|
||||
|
@ -668,10 +668,10 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
|||
|
||||
// Because of flow limiting, we cannot ensure that the callback will be
|
||||
// invoked `iterationCount` times.
|
||||
// An normal expectation will fail if expectation.fullfill() is not called
|
||||
// An normal expectation will fail if expectation.fulfill() is not called
|
||||
// `expectation.expectedFulfillmentCount` times.
|
||||
// If `expectation.isInverted = true`, the test will only succeed if
|
||||
// expectation is not fullfilled for the specified `expectedFulfillmentCount`.
|
||||
// expectation is not fulfilled for the specified `expectedFulfillmentCount`.
|
||||
// Since in our case we cannot predict how many times the expectation is
|
||||
// supposed to be fullfilled setting,
|
||||
// `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
|
||||
|
|
|
@ -673,15 +673,15 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
|||
|
||||
// Because of flow limiting, we cannot ensure that the callback will be
|
||||
// invoked `iterationCount` times.
|
||||
// An normal expectation will fail if expectation.fullfill() is not called
|
||||
// An normal expectation will fail if expectation.fulfill() is not called
|
||||
// `expectation.expectedFulfillmentCount` times.
|
||||
// If `expectation.isInverted = true`, the test will only succeed if
|
||||
// expectation is not fullfilled for the specified `expectedFulfillmentCount`.
|
||||
// expectation is not fulfilled for the specified `expectedFulfillmentCount`.
|
||||
// Since in our case we cannot predict how many times the expectation is
|
||||
// supposed to be fullfilled setting,
|
||||
// supposed to be fulfilled setting,
|
||||
// `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
|
||||
// `expectation.isInverted = true` ensures that test succeeds if
|
||||
// expectation is fullfilled <= `iterationCount` times.
|
||||
// expectation is fulfilled <= `iterationCount` times.
|
||||
XCTestExpectation *expectation = [[XCTestExpectation alloc]
|
||||
initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"];
|
||||
expectation.expectedFulfillmentCount = iterationCount + 1;
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
|
||||
static const int kMicroSecondsPerMilliSecond = 1000;
|
||||
static const int kMicrosecondsPerMillisecond = 1000;
|
||||
|
||||
namespace {
|
||||
using ClassificationResultProto =
|
||||
|
@ -29,9 +29,9 @@ using ::mediapipe::Packet;
|
|||
|
||||
+ (nullable MPPImageClassifierResult *)imageClassifierResultWithClassificationsPacket:
|
||||
(const Packet &)packet {
|
||||
// Even if packet does not validate a the expected type, you can safely access the timestamp.
|
||||
// Even if packet does not validate as the expected type, you can safely access the timestamp.
|
||||
NSInteger timestampInMilliSeconds =
|
||||
(NSInteger)(packet.Timestamp().Value() / kMicroSecondsPerMilliSecond);
|
||||
(NSInteger)(packet.Timestamp().Value() / kMicrosecondsPerMillisecond);
|
||||
|
||||
if (!packet.ValidateAsType<ClassificationResultProto>().ok()) {
|
||||
// MPPClassificationResult's timestamp is populated from timestamp `ClassificationResultProto`'s
|
||||
|
@ -48,7 +48,7 @@ using ::mediapipe::Packet;
|
|||
return [[MPPImageClassifierResult alloc]
|
||||
initWithClassificationResult:classificationResult
|
||||
timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
|
||||
kMicroSecondsPerMilliSecond)];
|
||||
kMicrosecondsPerMillisecond)];
|
||||
}
|
||||
|
||||
@end
|
||||
|
|
|
@ -166,7 +166,7 @@ public class BaseVisionTaskApi implements AutoCloseable {
|
|||
// For 90° and 270° rotations, we need to swap width and height.
|
||||
// This is due to the internal behavior of ImageToTensorCalculator, which:
|
||||
// - first denormalizes the provided rect by multiplying the rect width or
|
||||
// height by the image width or height, repectively.
|
||||
// height by the image width or height, respectively.
|
||||
// - then rotates this by denormalized rect by the provided rotation, and
|
||||
// uses this for cropping,
|
||||
// - then finally rotates this back.
|
||||
|
|
|
@ -34,8 +34,8 @@ public abstract class ImageSegmenterResult implements TaskResult {
|
|||
* @param categoryMask an {@link Optional} MPImage in IMAGE_FORMAT_ALPHA format representing a
|
||||
* category mask, where each pixel represents the class which the pixel in the original image
|
||||
* was predicted to belong to.
|
||||
* @param qualityScores The quality scores of the result masks, in the range of [0, 1]. Default to
|
||||
* `1` if the model doesn't output quality scores. Each element corresponds to the score of
|
||||
* @param qualityScores The quality scores of the result masks, in the range of [0, 1]. Defaults
|
||||
* to `1` if the model doesn't output quality scores. Each element corresponds to the score of
|
||||
* the category in the model outputs.
|
||||
* @param timestampMs a timestamp for this result.
|
||||
*/
|
||||
|
|
|
@ -27,13 +27,14 @@ export declare interface Detection {
|
|||
boundingBox?: BoundingBox;
|
||||
|
||||
/**
|
||||
* Optional list of keypoints associated with the detection. Keypoints
|
||||
* represent interesting points related to the detection. For example, the
|
||||
* keypoints represent the eye, ear and mouth from face detection model. Or
|
||||
* in the template matching detection, e.g. KNIFT, they can represent the
|
||||
* feature points for template matching.
|
||||
* List of keypoints associated with the detection. Keypoints represent
|
||||
* interesting points related to the detection. For example, the keypoints
|
||||
* represent the eye, ear and mouth from face detection model. Or in the
|
||||
* template matching detection, e.g. KNIFT, they can represent the feature
|
||||
* points for template matching. Contains an empty list if no keypoints are
|
||||
* detected.
|
||||
*/
|
||||
keypoints?: NormalizedKeypoint[];
|
||||
keypoints: NormalizedKeypoint[];
|
||||
}
|
||||
|
||||
/** Detection results of a model. */
|
||||
|
|
|
@ -85,7 +85,8 @@ describe('convertFromDetectionProto()', () => {
|
|||
categoryName: '',
|
||||
displayName: '',
|
||||
}],
|
||||
boundingBox: {originX: 0, originY: 0, width: 0, height: 0}
|
||||
boundingBox: {originX: 0, originY: 0, width: 0, height: 0},
|
||||
keypoints: []
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
@ -26,7 +26,7 @@ export function convertFromDetectionProto(source: DetectionProto): Detection {
|
|||
const labels = source.getLabelList();
|
||||
const displayNames = source.getDisplayNameList();
|
||||
|
||||
const detection: Detection = {categories: []};
|
||||
const detection: Detection = {categories: [], keypoints: []};
|
||||
for (let i = 0; i < scores.length; i++) {
|
||||
detection.categories.push({
|
||||
score: scores[i],
|
||||
|
@ -47,7 +47,6 @@ export function convertFromDetectionProto(source: DetectionProto): Detection {
|
|||
}
|
||||
|
||||
if (source.getLocationData()?.getRelativeKeypointsList().length) {
|
||||
detection.keypoints = [];
|
||||
for (const keypoint of
|
||||
source.getLocationData()!.getRelativeKeypointsList()) {
|
||||
detection.keypoints.push({
|
||||
|
|
|
@ -191,7 +191,8 @@ describe('FaceDetector', () => {
|
|||
categoryName: '',
|
||||
displayName: '',
|
||||
}],
|
||||
boundingBox: {originX: 0, originY: 0, width: 0, height: 0}
|
||||
boundingBox: {originX: 0, originY: 0, width: 0, height: 0},
|
||||
keypoints: []
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
@ -171,7 +171,7 @@ export class FaceStylizer extends VisionTaskRunner {
|
|||
/**
|
||||
* Performs face stylization on the provided single image and returns the
|
||||
* result. This method creates a copy of the resulting image and should not be
|
||||
* used in high-throughput applictions. Only use this method when the
|
||||
* used in high-throughput applications. Only use this method when the
|
||||
* FaceStylizer is created with the image running mode.
|
||||
*
|
||||
* @param image An image to process.
|
||||
|
@ -182,7 +182,7 @@ export class FaceStylizer extends VisionTaskRunner {
|
|||
/**
|
||||
* Performs face stylization on the provided single image and returns the
|
||||
* result. This method creates a copy of the resulting image and should not be
|
||||
* used in high-throughput applictions. Only use this method when the
|
||||
* used in high-throughput applications. Only use this method when the
|
||||
* FaceStylizer is created with the image running mode.
|
||||
*
|
||||
* The 'imageProcessingOptions' parameter can be used to specify one or all
|
||||
|
@ -275,7 +275,7 @@ export class FaceStylizer extends VisionTaskRunner {
|
|||
/**
|
||||
* Performs face stylization on the provided video frame. This method creates
|
||||
* a copy of the resulting image and should not be used in high-throughput
|
||||
* applictions. Only use this method when the FaceStylizer is created with the
|
||||
* applications. Only use this method when the FaceStylizer is created with the
|
||||
* video running mode.
|
||||
*
|
||||
* The input frame can be of any size. It's required to provide the video
|
||||
|
|
|
@ -39,6 +39,7 @@ const IMAGE_STREAM = 'image_in';
|
|||
const NORM_RECT_STREAM = 'norm_rect';
|
||||
const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
|
||||
const CATEGORY_MASK_STREAM = 'category_mask';
|
||||
const QUALITY_SCORES_STREAM = 'quality_scores';
|
||||
const IMAGE_SEGMENTER_GRAPH =
|
||||
'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
|
||||
const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
|
||||
|
@ -61,6 +62,7 @@ export type ImageSegmenterCallback = (result: ImageSegmenterResult) => void;
|
|||
export class ImageSegmenter extends VisionTaskRunner {
|
||||
private categoryMask?: MPMask;
|
||||
private confidenceMasks?: MPMask[];
|
||||
private qualityScores?: number[];
|
||||
private labels: string[] = [];
|
||||
private userCallback?: ImageSegmenterCallback;
|
||||
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
|
||||
|
@ -229,7 +231,7 @@ export class ImageSegmenter extends VisionTaskRunner {
|
|||
/**
|
||||
* Performs image segmentation on the provided single image and returns the
|
||||
* segmentation result. This method creates a copy of the resulting masks and
|
||||
* should not be used in high-throughput applictions. Only use this method
|
||||
* should not be used in high-throughput applications. Only use this method
|
||||
* when the ImageSegmenter is created with running mode `image`.
|
||||
*
|
||||
* @param image An image to process.
|
||||
|
@ -240,7 +242,7 @@ export class ImageSegmenter extends VisionTaskRunner {
|
|||
/**
|
||||
* Performs image segmentation on the provided single image and returns the
|
||||
* segmentation result. This method creates a copy of the resulting masks and
|
||||
* should not be used in high-v applictions. Only use this method when
|
||||
* should not be used in high-v applications. Only use this method when
|
||||
* the ImageSegmenter is created with running mode `image`.
|
||||
*
|
||||
* @param image An image to process.
|
||||
|
@ -318,7 +320,7 @@ export class ImageSegmenter extends VisionTaskRunner {
|
|||
/**
|
||||
* Performs image segmentation on the provided video frame and returns the
|
||||
* segmentation result. This method creates a copy of the resulting masks and
|
||||
* should not be used in high-v applictions. Only use this method when
|
||||
* should not be used in high-v applications. Only use this method when
|
||||
* the ImageSegmenter is created with running mode `video`.
|
||||
*
|
||||
* @param videoFrame A video frame to process.
|
||||
|
@ -367,12 +369,13 @@ export class ImageSegmenter extends VisionTaskRunner {
|
|||
private reset(): void {
|
||||
this.categoryMask = undefined;
|
||||
this.confidenceMasks = undefined;
|
||||
this.qualityScores = undefined;
|
||||
}
|
||||
|
||||
private processResults(): ImageSegmenterResult|void {
|
||||
try {
|
||||
const result =
|
||||
new ImageSegmenterResult(this.confidenceMasks, this.categoryMask);
|
||||
const result = new ImageSegmenterResult(
|
||||
this.confidenceMasks, this.categoryMask, this.qualityScores);
|
||||
if (this.userCallback) {
|
||||
this.userCallback(result);
|
||||
} else {
|
||||
|
@ -442,6 +445,20 @@ export class ImageSegmenter extends VisionTaskRunner {
|
|||
});
|
||||
}
|
||||
|
||||
graphConfig.addOutputStream(QUALITY_SCORES_STREAM);
|
||||
segmenterNode.addOutputStream('QUALITY_SCORES:' + QUALITY_SCORES_STREAM);
|
||||
|
||||
this.graphRunner.attachFloatVectorListener(
|
||||
QUALITY_SCORES_STREAM, (scores, timestamp) => {
|
||||
this.qualityScores = scores;
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
});
|
||||
this.graphRunner.attachEmptyPacketListener(
|
||||
QUALITY_SCORES_STREAM, timestamp => {
|
||||
this.categoryMask = undefined;
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
});
|
||||
|
||||
const binaryGraph = graphConfig.serializeBinary();
|
||||
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
|
||||
}
|
||||
|
|
|
@ -30,7 +30,13 @@ export class ImageSegmenterResult {
|
|||
* `WebGLTexture`-backed `MPImage` where each pixel represents the class
|
||||
* which the pixel in the original image was predicted to belong to.
|
||||
*/
|
||||
readonly categoryMask?: MPMask) {}
|
||||
readonly categoryMask?: MPMask,
|
||||
/**
|
||||
* The quality scores of the result masks, in the range of [0, 1].
|
||||
* Defaults to `1` if the model doesn't output quality scores. Each
|
||||
* element corresponds to the score of the category in the model outputs.
|
||||
*/
|
||||
readonly qualityScores?: number[]) {}
|
||||
|
||||
/** Frees the resources held by the category and confidence masks. */
|
||||
close(): void {
|
||||
|
|
|
@ -35,6 +35,8 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
|
|||
((images: WasmImage, timestamp: number) => void)|undefined;
|
||||
confidenceMasksListener:
|
||||
((images: WasmImage[], timestamp: number) => void)|undefined;
|
||||
qualityScoresListener:
|
||||
((data: number[], timestamp: number) => void)|undefined;
|
||||
|
||||
constructor() {
|
||||
super(createSpyWasmModule(), /* glCanvas= */ null);
|
||||
|
@ -52,6 +54,12 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
|
|||
expect(stream).toEqual('confidence_masks');
|
||||
this.confidenceMasksListener = listener;
|
||||
});
|
||||
this.attachListenerSpies[2] =
|
||||
spyOn(this.graphRunner, 'attachFloatVectorListener')
|
||||
.and.callFake((stream, listener) => {
|
||||
expect(stream).toEqual('quality_scores');
|
||||
this.qualityScoresListener = listener;
|
||||
});
|
||||
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
|
||||
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
|
||||
});
|
||||
|
@ -266,6 +274,7 @@ describe('ImageSegmenter', () => {
|
|||
it('invokes listener after masks are available', async () => {
|
||||
const categoryMask = new Uint8Array([1]);
|
||||
const confidenceMask = new Float32Array([0.0]);
|
||||
const qualityScores = [1.0];
|
||||
let listenerCalled = false;
|
||||
|
||||
await imageSegmenter.setOptions(
|
||||
|
@ -283,11 +292,16 @@ describe('ImageSegmenter', () => {
|
|||
],
|
||||
1337);
|
||||
expect(listenerCalled).toBeFalse();
|
||||
imageSegmenter.qualityScoresListener!(qualityScores, 1337);
|
||||
expect(listenerCalled).toBeFalse();
|
||||
});
|
||||
|
||||
return new Promise<void>(resolve => {
|
||||
imageSegmenter.segment({} as HTMLImageElement, () => {
|
||||
imageSegmenter.segment({} as HTMLImageElement, result => {
|
||||
listenerCalled = true;
|
||||
expect(result.categoryMask).toBeInstanceOf(MPMask);
|
||||
expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
|
||||
expect(result.qualityScores).toEqual(qualityScores);
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
|
|
@ -42,6 +42,7 @@ const NORM_RECT_IN_STREAM = 'norm_rect_in';
|
|||
const ROI_IN_STREAM = 'roi_in';
|
||||
const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
|
||||
const CATEGORY_MASK_STREAM = 'category_mask';
|
||||
const QUALITY_SCORES_STREAM = 'quality_scores';
|
||||
const IMAGEA_SEGMENTER_GRAPH =
|
||||
'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph';
|
||||
const DEFAULT_OUTPUT_CATEGORY_MASK = false;
|
||||
|
@ -86,6 +87,7 @@ export type InteractiveSegmenterCallback =
|
|||
export class InteractiveSegmenter extends VisionTaskRunner {
|
||||
private categoryMask?: MPMask;
|
||||
private confidenceMasks?: MPMask[];
|
||||
private qualityScores?: number[];
|
||||
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
|
||||
private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
|
||||
private userCallback?: InteractiveSegmenterCallback;
|
||||
|
@ -284,12 +286,13 @@ export class InteractiveSegmenter extends VisionTaskRunner {
|
|||
private reset(): void {
|
||||
this.confidenceMasks = undefined;
|
||||
this.categoryMask = undefined;
|
||||
this.qualityScores = undefined;
|
||||
}
|
||||
|
||||
private processResults(): InteractiveSegmenterResult|void {
|
||||
try {
|
||||
const result = new InteractiveSegmenterResult(
|
||||
this.confidenceMasks, this.categoryMask);
|
||||
this.confidenceMasks, this.categoryMask, this.qualityScores);
|
||||
if (this.userCallback) {
|
||||
this.userCallback(result);
|
||||
} else {
|
||||
|
@ -361,6 +364,20 @@ export class InteractiveSegmenter extends VisionTaskRunner {
|
|||
});
|
||||
}
|
||||
|
||||
graphConfig.addOutputStream(QUALITY_SCORES_STREAM);
|
||||
segmenterNode.addOutputStream('QUALITY_SCORES:' + QUALITY_SCORES_STREAM);
|
||||
|
||||
this.graphRunner.attachFloatVectorListener(
|
||||
QUALITY_SCORES_STREAM, (scores, timestamp) => {
|
||||
this.qualityScores = scores;
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
});
|
||||
this.graphRunner.attachEmptyPacketListener(
|
||||
QUALITY_SCORES_STREAM, timestamp => {
|
||||
this.categoryMask = undefined;
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
});
|
||||
|
||||
const binaryGraph = graphConfig.serializeBinary();
|
||||
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
|
||||
}
|
||||
|
|
|
@ -30,7 +30,13 @@ export class InteractiveSegmenterResult {
|
|||
* `WebGLTexture`-backed `MPImage` where each pixel represents the class
|
||||
* which the pixel in the original image was predicted to belong to.
|
||||
*/
|
||||
readonly categoryMask?: MPMask) {}
|
||||
readonly categoryMask?: MPMask,
|
||||
/**
|
||||
* The quality scores of the result masks, in the range of [0, 1].
|
||||
* Defaults to `1` if the model doesn't output quality scores. Each
|
||||
* element corresponds to the score of the category in the model outputs.
|
||||
*/
|
||||
readonly qualityScores?: number[]) {}
|
||||
|
||||
/** Frees the resources held by the category and confidence masks. */
|
||||
close(): void {
|
||||
|
|
|
@ -46,6 +46,8 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
|
|||
((images: WasmImage, timestamp: number) => void)|undefined;
|
||||
confidenceMasksListener:
|
||||
((images: WasmImage[], timestamp: number) => void)|undefined;
|
||||
qualityScoresListener:
|
||||
((data: number[], timestamp: number) => void)|undefined;
|
||||
lastRoi?: RenderDataProto;
|
||||
|
||||
constructor() {
|
||||
|
@ -64,6 +66,12 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
|
|||
expect(stream).toEqual('confidence_masks');
|
||||
this.confidenceMasksListener = listener;
|
||||
});
|
||||
this.attachListenerSpies[2] =
|
||||
spyOn(this.graphRunner, 'attachFloatVectorListener')
|
||||
.and.callFake((stream, listener) => {
|
||||
expect(stream).toEqual('quality_scores');
|
||||
this.qualityScoresListener = listener;
|
||||
});
|
||||
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
|
||||
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
|
||||
});
|
||||
|
@ -277,9 +285,10 @@ describe('InteractiveSegmenter', () => {
|
|||
});
|
||||
});
|
||||
|
||||
it('invokes listener after masks are avaiblae', async () => {
|
||||
it('invokes listener after masks are available', async () => {
|
||||
const categoryMask = new Uint8Array([1]);
|
||||
const confidenceMask = new Float32Array([0.0]);
|
||||
const qualityScores = [1.0];
|
||||
let listenerCalled = false;
|
||||
|
||||
await interactiveSegmenter.setOptions(
|
||||
|
@ -297,11 +306,16 @@ describe('InteractiveSegmenter', () => {
|
|||
],
|
||||
1337);
|
||||
expect(listenerCalled).toBeFalse();
|
||||
interactiveSegmenter.qualityScoresListener!(qualityScores, 1337);
|
||||
expect(listenerCalled).toBeFalse();
|
||||
});
|
||||
|
||||
return new Promise<void>(resolve => {
|
||||
interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, () => {
|
||||
interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, result => {
|
||||
listenerCalled = true;
|
||||
expect(result.categoryMask).toBeInstanceOf(MPMask);
|
||||
expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
|
||||
expect(result.qualityScores).toEqual(qualityScores);
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
|
|
@ -210,7 +210,8 @@ describe('ObjectDetector', () => {
|
|||
categoryName: '',
|
||||
displayName: '',
|
||||
}],
|
||||
boundingBox: {originX: 0, originY: 0, width: 0, height: 0}
|
||||
boundingBox: {originX: 0, originY: 0, width: 0, height: 0},
|
||||
keypoints: []
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
Loading…
Reference in New Issue
Block a user