diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java index fd7999610..b6e244d1f 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java @@ -79,7 +79,9 @@ public class BaseVisionTaskApi implements AutoCloseable { inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); inputPackets.put( normRectStreamName, - runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions))); + runner + .getPacketCreator() + .createProto(convertToNormalizedRect(imageProcessingOptions, image))); return runner.process(inputPackets); } @@ -105,7 +107,9 @@ public class BaseVisionTaskApi implements AutoCloseable { inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); inputPackets.put( normRectStreamName, - runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions))); + runner + .getPacketCreator() + .createProto(convertToNormalizedRect(imageProcessingOptions, image))); return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); } @@ -131,7 +135,9 @@ public class BaseVisionTaskApi implements AutoCloseable { inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); inputPackets.put( normRectStreamName, - runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions))); + runner + .getPacketCreator() + .createProto(convertToNormalizedRect(imageProcessingOptions, image))); runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); } @@ -146,16 +152,30 @@ public class BaseVisionTaskApi implements AutoCloseable { * message. */ protected static NormalizedRect convertToNormalizedRect( - ImageProcessingOptions imageProcessingOptions) { + ImageProcessingOptions imageProcessingOptions, MPImage image) { RectF regionOfInterest = imageProcessingOptions.regionOfInterest().isPresent() ? imageProcessingOptions.regionOfInterest().get() : new RectF(0, 0, 1, 1); + // For 90° and 270° rotations, we need to swap width and height. + // This is due to the internal behavior of ImageToTensorCalculator, which: + // - first denormalizes the provided rect by multiplying the rect width or + // height by the image width or height, repectively. + // - then rotates this by denormalized rect by the provided rotation, and + // uses this for cropping, + // - then finally rotates this back. + boolean requiresSwap = imageProcessingOptions.rotationDegrees() % 180 != 0; return NormalizedRect.newBuilder() .setXCenter(regionOfInterest.centerX()) .setYCenter(regionOfInterest.centerY()) - .setWidth(regionOfInterest.width()) - .setHeight(regionOfInterest.height()) + .setWidth( + requiresSwap + ? regionOfInterest.height() * image.getHeight() / image.getWidth() + : regionOfInterest.width()) + .setHeight( + requiresSwap + ? regionOfInterest.width() * image.getWidth() / image.getHeight() + : regionOfInterest.height()) // Convert to radians anti-clockwise. .setRotation(-(float) Math.PI * imageProcessingOptions.rotationDegrees() / 180.0f) .build(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java index 657716b6b..753cdc631 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java @@ -550,7 +550,9 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { inputPackets.put(ROI_IN_STREAM_NAME, runner.getPacketCreator().createProto(renderData)); inputPackets.put( NORM_RECT_IN_STREAM_NAME, - runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions))); + runner + .getPacketCreator() + .createProto(convertToNormalizedRect(imageProcessingOptions, image))); return (ImageSegmenterResult) runner.process(inputPackets); } } diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java index dac11bf02..3da4ea9b5 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java @@ -265,9 +265,9 @@ public class ImageClassifierTest { assertCategoriesAre( results, Arrays.asList( - Category.create(0.6390683f, 934, "cheeseburger", ""), - Category.create(0.0495407f, 963, "meat loaf", ""), - Category.create(0.0469720f, 925, "guacamole", ""))); + Category.create(0.75369555f, 934, "cheeseburger", ""), + Category.create(0.029219573f, 925, "guacamole", ""), + Category.create(0.028840661f, 932, "bagel", ""))); } @Test @@ -279,8 +279,8 @@ public class ImageClassifierTest { .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); - // RectF around the chair. - RectF roi = new RectF(0.0f, 0.1763f, 0.5642f, 0.3049f); + // RectF around the soccer ball. + RectF roi = new RectF(0.2655f, 0.45f, 0.6925f, 0.614f); ImageProcessingOptions imageProcessingOptions = ImageProcessingOptions.builder().setRegionOfInterest(roi).setRotationDegrees(-90).build(); ImageClassifierResult results = @@ -289,7 +289,7 @@ public class ImageClassifierTest { assertHasOneHead(results); assertCategoriesAre( - results, Arrays.asList(Category.create(0.686824f, 560, "folding chair", ""))); + results, Arrays.asList(Category.create(0.99730396f, 806, "soccer ball", ""))); } } diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java index 8dec6f80b..d54586f6f 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java @@ -188,7 +188,7 @@ public class ImageEmbedderTest { ImageEmbedder.cosineSimilarity( result.embeddingResult().embeddings().get(0), resultRotated.embeddingResult().embeddings().get(0)); - assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.571648426f); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.982316669f); } @Test @@ -212,7 +212,7 @@ public class ImageEmbedderTest { ImageEmbedder.cosineSimilarity( resultRoiRotated.embeddingResult().embeddings().get(0), resultCrop.embeddingResult().embeddings().get(0)); - assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.62780395f); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.9745944861f); } } diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorTest.java index 102083d61..58bcaafda 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorTest.java @@ -236,7 +236,6 @@ public class ObjectDetectorTest { ObjectDetectorOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) .setMaxResults(1) - .setCategoryAllowlist(Arrays.asList("cat")) .build(); ObjectDetector objectDetector = ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -246,7 +245,7 @@ public class ObjectDetectorTest { objectDetector.detect( getImageFromAsset(CAT_AND_DOG_ROTATED_IMAGE), imageProcessingOptions); - assertContainsOnlyCat(results, new RectF(22.0f, 611.0f, 452.0f, 890.0f), 0.7109375f); + assertContainsOnlyCat(results, new RectF(0.0f, 608.0f, 439.0f, 995.0f), 0.69921875f); } @Test @@ -326,14 +325,14 @@ public class ObjectDetectorTest { MediaPipeException.class, () -> objectDetector.detectForVideo( - getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0)); + getImageFromAsset(CAT_AND_DOG_IMAGE), /* timestampsMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); exception = assertThrows( MediaPipeException.class, () -> objectDetector.detectAsync( - getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0)); + getImageFromAsset(CAT_AND_DOG_IMAGE), /* timestampsMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -357,7 +356,7 @@ public class ObjectDetectorTest { MediaPipeException.class, () -> objectDetector.detectAsync( - getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0)); + getImageFromAsset(CAT_AND_DOG_IMAGE), /* timestampsMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -383,7 +382,7 @@ public class ObjectDetectorTest { MediaPipeException.class, () -> objectDetector.detectForVideo( - getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0)); + getImageFromAsset(CAT_AND_DOG_IMAGE), /* timestampsMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); } @@ -414,7 +413,7 @@ public class ObjectDetectorTest { for (int i = 0; i < 3; i++) { ObjectDetectionResult results = objectDetector.detectForVideo( - getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ i); + getImageFromAsset(CAT_AND_DOG_IMAGE), /* timestampsMs= */ i); assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); } } @@ -435,11 +434,11 @@ public class ObjectDetectorTest { .build(); try (ObjectDetector objectDetector = ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { - objectDetector.detectAsync(image, /*timestampsMs=*/ 1); + objectDetector.detectAsync(image, /* timestampsMs= */ 1); MediaPipeException exception = assertThrows( MediaPipeException.class, - () -> objectDetector.detectAsync(image, /*timestampsMs=*/ 0)); + () -> objectDetector.detectAsync(image, /* timestampsMs= */ 0)); assertThat(exception) .hasMessageThat() .contains("having a smaller timestamp than the processed timestamp"); @@ -463,7 +462,7 @@ public class ObjectDetectorTest { try (ObjectDetector objectDetector = ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { for (int i = 0; i < 3; i++) { - objectDetector.detectAsync(image, /*timestampsMs=*/ i); + objectDetector.detectAsync(image, /* timestampsMs= */ i); } } }