From d2780e4251fe2b27561f2dc7ede43579db842fa8 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Wed, 29 Mar 2023 21:06:50 +0530 Subject: [PATCH] Added swift image classifier test for quantized model --- .../ImageClassifierTests.swift | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/mediapipe/tasks/ios/test/vision/image_classifier/ImageClassifierTests.swift b/mediapipe/tasks/ios/test/vision/image_classifier/ImageClassifierTests.swift index 62bf0d487..837319d6f 100644 --- a/mediapipe/tasks/ios/test/vision/image_classifier/ImageClassifierTests.swift +++ b/mediapipe/tasks/ios/test/vision/image_classifier/ImageClassifierTests.swift @@ -28,6 +28,10 @@ class ImageClassifierTests: XCTestCase { forResource: "mobilenet_v2_1.0_224", ofType: "tflite") + static let quantizedModelPath = bundle.path( + forResource: "mobilenet_v1_0.25_224_quant", + ofType: "tflite") + static let burgerImage = FileInfo(name: "burger", type: "jpg") static let burgerRotatedImage = FileInfo(name: "burger_rotated", type: "jpg") static let multiObjectsImage = FileInfo(name: "multi_objects", type: "jpg") @@ -258,6 +262,30 @@ class ImageClassifierTests: XCTestCase { ImageClassifierTests.expectedResultsClassifyBurgerImageWithFloatModel) } + func testClassifyWithQuantizedModelSucceeds() throws { + + let imageClassifierOptions = + try XCTUnwrap( + imageClassifierOptionsWithModelPath( + ImageClassifierTests.quantizedModelPath)) + + let imageClassifier = try XCTUnwrap(ImageClassifier(options: imageClassifierOptions)) + + let expectedCategories = [ + ResultCategory( + index: 934, + score: 0.972656, + categoryName: "cheeseburger", + displayName: nil), + ] + + try assertResultsForClassifyImageWithFileInfo( + ImageClassifierTests.burgerImage, + usingImageClassifier: imageClassifier, + hasCategoryCount: ImageClassifierTests.mobileNetCategoriesCount, + andCategories: expectedCategories) + } + func testClassifyWithScoreThresholdSucceeds() throws { let imageClassifierOptions =