Internal change

PiperOrigin-RevId: 522275233
This commit is contained in:
MediaPipe Team 2023-04-06 01:33:23 -07:00 committed by Copybara-Service
parent 22186299c4
commit 8d8ab9a972
5 changed files with 46 additions and 25 deletions

View File

@ -79,7 +79,9 @@ public class BaseVisionTaskApi implements AutoCloseable {
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
inputPackets.put(
normRectStreamName,
runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions)));
runner
.getPacketCreator()
.createProto(convertToNormalizedRect(imageProcessingOptions, image)));
return runner.process(inputPackets);
}
@ -105,7 +107,9 @@ public class BaseVisionTaskApi implements AutoCloseable {
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
inputPackets.put(
normRectStreamName,
runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions)));
runner
.getPacketCreator()
.createProto(convertToNormalizedRect(imageProcessingOptions, image)));
return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
}
@ -131,7 +135,9 @@ public class BaseVisionTaskApi implements AutoCloseable {
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
inputPackets.put(
normRectStreamName,
runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions)));
runner
.getPacketCreator()
.createProto(convertToNormalizedRect(imageProcessingOptions, image)));
runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
}
@ -146,16 +152,30 @@ public class BaseVisionTaskApi implements AutoCloseable {
* message.
*/
protected static NormalizedRect convertToNormalizedRect(
ImageProcessingOptions imageProcessingOptions) {
ImageProcessingOptions imageProcessingOptions, MPImage image) {
RectF regionOfInterest =
imageProcessingOptions.regionOfInterest().isPresent()
? imageProcessingOptions.regionOfInterest().get()
: new RectF(0, 0, 1, 1);
// For 90° and 270° rotations, we need to swap width and height.
// This is due to the internal behavior of ImageToTensorCalculator, which:
// - first denormalizes the provided rect by multiplying the rect width or
// height by the image width or height, repectively.
// - then rotates this by denormalized rect by the provided rotation, and
// uses this for cropping,
// - then finally rotates this back.
boolean requiresSwap = imageProcessingOptions.rotationDegrees() % 180 != 0;
return NormalizedRect.newBuilder()
.setXCenter(regionOfInterest.centerX())
.setYCenter(regionOfInterest.centerY())
.setWidth(regionOfInterest.width())
.setHeight(regionOfInterest.height())
.setWidth(
requiresSwap
? regionOfInterest.height() * image.getHeight() / image.getWidth()
: regionOfInterest.width())
.setHeight(
requiresSwap
? regionOfInterest.width() * image.getWidth() / image.getHeight()
: regionOfInterest.height())
// Convert to radians anti-clockwise.
.setRotation(-(float) Math.PI * imageProcessingOptions.rotationDegrees() / 180.0f)
.build();

View File

@ -550,7 +550,9 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
inputPackets.put(ROI_IN_STREAM_NAME, runner.getPacketCreator().createProto(renderData));
inputPackets.put(
NORM_RECT_IN_STREAM_NAME,
runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions)));
runner
.getPacketCreator()
.createProto(convertToNormalizedRect(imageProcessingOptions, image)));
return (ImageSegmenterResult) runner.process(inputPackets);
}
}

View File

@ -265,9 +265,9 @@ public class ImageClassifierTest {
assertCategoriesAre(
results,
Arrays.asList(
Category.create(0.6390683f, 934, "cheeseburger", ""),
Category.create(0.0495407f, 963, "meat loaf", ""),
Category.create(0.0469720f, 925, "guacamole", "")));
Category.create(0.75369555f, 934, "cheeseburger", ""),
Category.create(0.029219573f, 925, "guacamole", ""),
Category.create(0.028840661f, 932, "bagel", "")));
}
@Test
@ -279,8 +279,8 @@ public class ImageClassifierTest {
.build();
ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
// RectF around the chair.
RectF roi = new RectF(0.0f, 0.1763f, 0.5642f, 0.3049f);
// RectF around the soccer ball.
RectF roi = new RectF(0.2655f, 0.45f, 0.6925f, 0.614f);
ImageProcessingOptions imageProcessingOptions =
ImageProcessingOptions.builder().setRegionOfInterest(roi).setRotationDegrees(-90).build();
ImageClassifierResult results =
@ -289,7 +289,7 @@ public class ImageClassifierTest {
assertHasOneHead(results);
assertCategoriesAre(
results, Arrays.asList(Category.create(0.686824f, 560, "folding chair", "")));
results, Arrays.asList(Category.create(0.99730396f, 806, "soccer ball", "")));
}
}

View File

@ -188,7 +188,7 @@ public class ImageEmbedderTest {
ImageEmbedder.cosineSimilarity(
result.embeddingResult().embeddings().get(0),
resultRotated.embeddingResult().embeddings().get(0));
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.571648426f);
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.982316669f);
}
@Test
@ -212,7 +212,7 @@ public class ImageEmbedderTest {
ImageEmbedder.cosineSimilarity(
resultRoiRotated.embeddingResult().embeddings().get(0),
resultCrop.embeddingResult().embeddings().get(0));
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.62780395f);
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.9745944861f);
}
}

View File

@ -236,7 +236,6 @@ public class ObjectDetectorTest {
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setMaxResults(1)
.setCategoryAllowlist(Arrays.asList("cat"))
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
@ -246,7 +245,7 @@ public class ObjectDetectorTest {
objectDetector.detect(
getImageFromAsset(CAT_AND_DOG_ROTATED_IMAGE), imageProcessingOptions);
assertContainsOnlyCat(results, new RectF(22.0f, 611.0f, 452.0f, 890.0f), 0.7109375f);
assertContainsOnlyCat(results, new RectF(0.0f, 608.0f, 439.0f, 995.0f), 0.69921875f);
}
@Test
@ -326,14 +325,14 @@ public class ObjectDetectorTest {
MediaPipeException.class,
() ->
objectDetector.detectForVideo(
getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0));
getImageFromAsset(CAT_AND_DOG_IMAGE), /* timestampsMs= */ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
exception =
assertThrows(
MediaPipeException.class,
() ->
objectDetector.detectAsync(
getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0));
getImageFromAsset(CAT_AND_DOG_IMAGE), /* timestampsMs= */ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
}
@ -357,7 +356,7 @@ public class ObjectDetectorTest {
MediaPipeException.class,
() ->
objectDetector.detectAsync(
getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0));
getImageFromAsset(CAT_AND_DOG_IMAGE), /* timestampsMs= */ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
}
@ -383,7 +382,7 @@ public class ObjectDetectorTest {
MediaPipeException.class,
() ->
objectDetector.detectForVideo(
getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0));
getImageFromAsset(CAT_AND_DOG_IMAGE), /* timestampsMs= */ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
}
@ -414,7 +413,7 @@ public class ObjectDetectorTest {
for (int i = 0; i < 3; i++) {
ObjectDetectionResult results =
objectDetector.detectForVideo(
getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ i);
getImageFromAsset(CAT_AND_DOG_IMAGE), /* timestampsMs= */ i);
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
}
}
@ -435,11 +434,11 @@ public class ObjectDetectorTest {
.build();
try (ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
objectDetector.detectAsync(image, /*timestampsMs=*/ 1);
objectDetector.detectAsync(image, /* timestampsMs= */ 1);
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() -> objectDetector.detectAsync(image, /*timestampsMs=*/ 0));
() -> objectDetector.detectAsync(image, /* timestampsMs= */ 0));
assertThat(exception)
.hasMessageThat()
.contains("having a smaller timestamp than the processed timestamp");
@ -463,7 +462,7 @@ public class ObjectDetectorTest {
try (ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
for (int i = 0; i < 3; i++) {
objectDetector.detectAsync(image, /*timestampsMs=*/ i);
objectDetector.detectAsync(image, /* timestampsMs= */ i);
}
}
}