Internal change

PiperOrigin-RevId: 522275233
This commit is contained in:
MediaPipe Team 2023-04-06 01:33:23 -07:00 committed by Copybara-Service
parent 22186299c4
commit 8d8ab9a972
5 changed files with 46 additions and 25 deletions

View File

@ -79,7 +79,9 @@ public class BaseVisionTaskApi implements AutoCloseable {
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
inputPackets.put( inputPackets.put(
normRectStreamName, normRectStreamName,
runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions))); runner
.getPacketCreator()
.createProto(convertToNormalizedRect(imageProcessingOptions, image)));
return runner.process(inputPackets); return runner.process(inputPackets);
} }
@ -105,7 +107,9 @@ public class BaseVisionTaskApi implements AutoCloseable {
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
inputPackets.put( inputPackets.put(
normRectStreamName, normRectStreamName,
runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions))); runner
.getPacketCreator()
.createProto(convertToNormalizedRect(imageProcessingOptions, image)));
return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
} }
@ -131,7 +135,9 @@ public class BaseVisionTaskApi implements AutoCloseable {
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
inputPackets.put( inputPackets.put(
normRectStreamName, normRectStreamName,
runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions))); runner
.getPacketCreator()
.createProto(convertToNormalizedRect(imageProcessingOptions, image)));
runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
} }
@ -146,16 +152,30 @@ public class BaseVisionTaskApi implements AutoCloseable {
* message. * message.
*/ */
protected static NormalizedRect convertToNormalizedRect( protected static NormalizedRect convertToNormalizedRect(
ImageProcessingOptions imageProcessingOptions) { ImageProcessingOptions imageProcessingOptions, MPImage image) {
RectF regionOfInterest = RectF regionOfInterest =
imageProcessingOptions.regionOfInterest().isPresent() imageProcessingOptions.regionOfInterest().isPresent()
? imageProcessingOptions.regionOfInterest().get() ? imageProcessingOptions.regionOfInterest().get()
: new RectF(0, 0, 1, 1); : new RectF(0, 0, 1, 1);
// For 90° and 270° rotations, we need to swap width and height.
// This is due to the internal behavior of ImageToTensorCalculator, which:
// - first denormalizes the provided rect by multiplying the rect width or
// height by the image width or height, repectively.
// - then rotates this by denormalized rect by the provided rotation, and
// uses this for cropping,
// - then finally rotates this back.
boolean requiresSwap = imageProcessingOptions.rotationDegrees() % 180 != 0;
return NormalizedRect.newBuilder() return NormalizedRect.newBuilder()
.setXCenter(regionOfInterest.centerX()) .setXCenter(regionOfInterest.centerX())
.setYCenter(regionOfInterest.centerY()) .setYCenter(regionOfInterest.centerY())
.setWidth(regionOfInterest.width()) .setWidth(
.setHeight(regionOfInterest.height()) requiresSwap
? regionOfInterest.height() * image.getHeight() / image.getWidth()
: regionOfInterest.width())
.setHeight(
requiresSwap
? regionOfInterest.width() * image.getWidth() / image.getHeight()
: regionOfInterest.height())
// Convert to radians anti-clockwise. // Convert to radians anti-clockwise.
.setRotation(-(float) Math.PI * imageProcessingOptions.rotationDegrees() / 180.0f) .setRotation(-(float) Math.PI * imageProcessingOptions.rotationDegrees() / 180.0f)
.build(); .build();

View File

@ -550,7 +550,9 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
inputPackets.put(ROI_IN_STREAM_NAME, runner.getPacketCreator().createProto(renderData)); inputPackets.put(ROI_IN_STREAM_NAME, runner.getPacketCreator().createProto(renderData));
inputPackets.put( inputPackets.put(
NORM_RECT_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME,
runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions))); runner
.getPacketCreator()
.createProto(convertToNormalizedRect(imageProcessingOptions, image)));
return (ImageSegmenterResult) runner.process(inputPackets); return (ImageSegmenterResult) runner.process(inputPackets);
} }
} }

View File

@ -265,9 +265,9 @@ public class ImageClassifierTest {
assertCategoriesAre( assertCategoriesAre(
results, results,
Arrays.asList( Arrays.asList(
Category.create(0.6390683f, 934, "cheeseburger", ""), Category.create(0.75369555f, 934, "cheeseburger", ""),
Category.create(0.0495407f, 963, "meat loaf", ""), Category.create(0.029219573f, 925, "guacamole", ""),
Category.create(0.0469720f, 925, "guacamole", ""))); Category.create(0.028840661f, 932, "bagel", "")));
} }
@Test @Test
@ -279,8 +279,8 @@ public class ImageClassifierTest {
.build(); .build();
ImageClassifier imageClassifier = ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
// RectF around the chair. // RectF around the soccer ball.
RectF roi = new RectF(0.0f, 0.1763f, 0.5642f, 0.3049f); RectF roi = new RectF(0.2655f, 0.45f, 0.6925f, 0.614f);
ImageProcessingOptions imageProcessingOptions = ImageProcessingOptions imageProcessingOptions =
ImageProcessingOptions.builder().setRegionOfInterest(roi).setRotationDegrees(-90).build(); ImageProcessingOptions.builder().setRegionOfInterest(roi).setRotationDegrees(-90).build();
ImageClassifierResult results = ImageClassifierResult results =
@ -289,7 +289,7 @@ public class ImageClassifierTest {
assertHasOneHead(results); assertHasOneHead(results);
assertCategoriesAre( assertCategoriesAre(
results, Arrays.asList(Category.create(0.686824f, 560, "folding chair", ""))); results, Arrays.asList(Category.create(0.99730396f, 806, "soccer ball", "")));
} }
} }

View File

@ -188,7 +188,7 @@ public class ImageEmbedderTest {
ImageEmbedder.cosineSimilarity( ImageEmbedder.cosineSimilarity(
result.embeddingResult().embeddings().get(0), result.embeddingResult().embeddings().get(0),
resultRotated.embeddingResult().embeddings().get(0)); resultRotated.embeddingResult().embeddings().get(0));
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.571648426f); assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.982316669f);
} }
@Test @Test
@ -212,7 +212,7 @@ public class ImageEmbedderTest {
ImageEmbedder.cosineSimilarity( ImageEmbedder.cosineSimilarity(
resultRoiRotated.embeddingResult().embeddings().get(0), resultRoiRotated.embeddingResult().embeddings().get(0),
resultCrop.embeddingResult().embeddings().get(0)); resultCrop.embeddingResult().embeddings().get(0));
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.62780395f); assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.9745944861f);
} }
} }

View File

@ -236,7 +236,6 @@ public class ObjectDetectorTest {
ObjectDetectorOptions.builder() ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setMaxResults(1) .setMaxResults(1)
.setCategoryAllowlist(Arrays.asList("cat"))
.build(); .build();
ObjectDetector objectDetector = ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
@ -246,7 +245,7 @@ public class ObjectDetectorTest {
objectDetector.detect( objectDetector.detect(
getImageFromAsset(CAT_AND_DOG_ROTATED_IMAGE), imageProcessingOptions); getImageFromAsset(CAT_AND_DOG_ROTATED_IMAGE), imageProcessingOptions);
assertContainsOnlyCat(results, new RectF(22.0f, 611.0f, 452.0f, 890.0f), 0.7109375f); assertContainsOnlyCat(results, new RectF(0.0f, 608.0f, 439.0f, 995.0f), 0.69921875f);
} }
@Test @Test