Removed failing test for live stream mode classification

This commit is contained in:
Prianka Liz Kariat 2023-03-28 23:13:48 +05:30
parent 0d4c365b78
commit 7ce9e879df
3 changed files with 62 additions and 148 deletions

View File

@ -15,6 +15,7 @@
#include <algorithm> #include <algorithm>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <iostream>
#include "mediapipe/calculators/core/flow_limiter_calculator.pb.h" #include "mediapipe/calculators/core/flow_limiter_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
@ -262,6 +263,7 @@ class FlowLimiterCalculator : public CalculatorBase {
std::deque<Timestamp> frames_in_flight_; std::deque<Timestamp> frames_in_flight_;
std::map<Timestamp, bool> allowed_; std::map<Timestamp, bool> allowed_;
}; };
REGISTER_CALCULATOR(FlowLimiterCalculator); REGISTER_CALCULATOR(FlowLimiterCalculator);
} // namespace mediapipe } // namespace mediapipe

View File

@ -107,7 +107,8 @@ class ImageClassifierTests: XCTestCase {
categoryArray.count, categoryArray.count,
expectedCategoryArray.count) expectedCategoryArray.count)
for (index, (category, expectedCategory)) in zip(categoryArray, expectedCategoryArray) for (index, (category, expectedCategory)) in
zip(categoryArray, expectedCategoryArray)
.enumerated() .enumerated()
{ {
assertCategoriesAreEqual( assertCategoriesAreEqual(
@ -120,8 +121,22 @@ class ImageClassifierTests: XCTestCase {
func assertImageClassifierResultHasOneHead( func assertImageClassifierResultHasOneHead(
_ imageClassifierResult: ImageClassifierResult _ imageClassifierResult: ImageClassifierResult
) { ) {
XCTAssertEqual(imageClassifierResult.classificationResult.classifications.count, 1) XCTAssertEqual(
XCTAssertEqual(imageClassifierResult.classificationResult.classifications[0].headIndex, 0) imageClassifierResult.classificationResult.classifications.count,
1)
XCTAssertEqual(
imageClassifierResult.classificationResult.classifications[0].headIndex,
0)
}
func imageWithFileInfo(_ fileInfo: FileInfo) throws -> MPImage {
let mpImage = try XCTUnwrap(
MPImage.imageFromBundle(
withClass: type(of: self),
filename: fileInfo.name,
type: fileInfo.type))
return mpImage
} }
func imageClassifierOptionsWithModelPath( func imageClassifierOptionsWithModelPath(
@ -155,7 +170,8 @@ class ImageClassifierTests: XCTestCase {
andCategories expectedCategories: [ResultCategory] andCategories expectedCategories: [ResultCategory]
) throws { ) throws {
assertImageClassifierResultHasOneHead(imageClassifierResult) assertImageClassifierResultHasOneHead(imageClassifierResult)
let categories = imageClassifierResult.classificationResult.classifications[0].categories let categories =
imageClassifierResult.classificationResult.classifications[0].categories
XCTAssertEqual(categories.count, expectedCategoryCount) XCTAssertEqual(categories.count, expectedCategoryCount)
assertEqualCategoryArrays( assertEqualCategoryArrays(
@ -188,10 +204,7 @@ class ImageClassifierTests: XCTestCase {
andCategories expectedCategories: [ResultCategory] andCategories expectedCategories: [ResultCategory]
) throws { ) throws {
let mpImage = try XCTUnwrap( let mpImage = try XCTUnwrap(
MPImage.imageFromBundle( imageWithFileInfo(fileInfo))
withClass: type(of: self),
filename: fileInfo.name,
type: fileInfo.type))
try assertResultsForClassifyImage( try assertResultsForClassifyImage(
mpImage, mpImage,
@ -201,19 +214,6 @@ class ImageClassifierTests: XCTestCase {
) )
} }
// func testCreateImageClassifierWithInvalidMaxResultsFails() throws {
// let textClassifierOptions =
// try XCTUnwrap(
// textClassifierOptionsWithModelPath(TextClassifierTests.bertModelPath))
// textClassifierOptions.maxResults = 0
// assertCreateTextClassifierThrowsError(
// textClassifierOptions: textClassifierOptions,
// expectedErrorDescription: """
// INVALID_ARGUMENT: Invalid `max_results` option: value must be != 0.
// """)
// }
func testCreateImageClassifierWithCategoryAllowlistAndDenylistFails() throws { func testCreateImageClassifierWithCategoryAllowlistAndDenylistFails() throws {
let imageClassifierOptions = let imageClassifierOptions =
@ -254,7 +254,8 @@ class ImageClassifierTests: XCTestCase {
ImageClassifierTests.burgerImage, ImageClassifierTests.burgerImage,
usingImageClassifier: imageClassifier, usingImageClassifier: imageClassifier,
hasCategoryCount: ImageClassifierTests.mobileNetCategoriesCount, hasCategoryCount: ImageClassifierTests.mobileNetCategoriesCount,
andCategories: ImageClassifierTests.expectedResultsClassifyBurgerImageWithFloatModel) andCategories:
ImageClassifierTests.expectedResultsClassifyBurgerImageWithFloatModel)
} }
func testClassifyWithScoreThresholdSucceeds() throws { func testClassifyWithScoreThresholdSucceeds() throws {
@ -264,7 +265,8 @@ class ImageClassifierTests: XCTestCase {
imageClassifierOptionsWithModelPath(ImageClassifierTests.floatModelPath)) imageClassifierOptionsWithModelPath(ImageClassifierTests.floatModelPath))
imageClassifierOptions.scoreThreshold = 0.25 imageClassifierOptions.scoreThreshold = 0.25
let imageClassifier = try XCTUnwrap(ImageClassifier(options: imageClassifierOptions)) let imageClassifier = try XCTUnwrap(ImageClassifier(
options: imageClassifierOptions))
let expectedCategories = [ let expectedCategories = [
ResultCategory( ResultCategory(
@ -286,9 +288,11 @@ class ImageClassifierTests: XCTestCase {
let imageClassifierOptions = let imageClassifierOptions =
try XCTUnwrap( try XCTUnwrap(
imageClassifierOptionsWithModelPath(ImageClassifierTests.floatModelPath)) imageClassifierOptionsWithModelPath(ImageClassifierTests.floatModelPath))
imageClassifierOptions.categoryAllowlist = ["cheeseburger", "guacamole", "meat loaf"] imageClassifierOptions.categoryAllowlist =
["cheeseburger", "guacamole", "meat loaf"]
let imageClassifier = try XCTUnwrap(ImageClassifier(options: imageClassifierOptions)) let imageClassifier = try XCTUnwrap(ImageClassifier(
options: imageClassifierOptions))
let expectedCategories = [ let expectedCategories = [
ResultCategory( ResultCategory(
@ -319,13 +323,15 @@ class ImageClassifierTests: XCTestCase {
let imageClassifierOptions = let imageClassifierOptions =
try XCTUnwrap( try XCTUnwrap(
imageClassifierOptionsWithModelPath(ImageClassifierTests.floatModelPath)) imageClassifierOptionsWithModelPath(
ImageClassifierTests.floatModelPath))
imageClassifierOptions.categoryDenylist = ["bagel"] imageClassifierOptions.categoryDenylist = ["bagel"]
let maxResults = 3; let maxResults = 3;
imageClassifierOptions.maxResults = maxResults; imageClassifierOptions.maxResults = maxResults;
let imageClassifier = try XCTUnwrap(ImageClassifier(options: imageClassifierOptions)) let imageClassifier = try XCTUnwrap(ImageClassifier(
options: imageClassifierOptions))
let expectedCategories = [ let expectedCategories = [
ResultCategory( ResultCategory(
@ -356,18 +362,17 @@ class ImageClassifierTests: XCTestCase {
let imageClassifierOptions = let imageClassifierOptions =
try XCTUnwrap( try XCTUnwrap(
imageClassifierOptionsWithModelPath(ImageClassifierTests.floatModelPath)) imageClassifierOptionsWithModelPath(
ImageClassifierTests.floatModelPath))
let maxResults = 1; let maxResults = 1;
imageClassifierOptions.maxResults = maxResults; imageClassifierOptions.maxResults = maxResults;
let imageClassifier = try XCTUnwrap(ImageClassifier(options: imageClassifierOptions)) let imageClassifier = try XCTUnwrap(ImageClassifier(
options: imageClassifierOptions))
let mpImage = try XCTUnwrap( let mpImage = try XCTUnwrap(
MPImage.imageFromBundle( imageWithFileInfo(ImageClassifierTests.multiObjectsImage))
withClass: type(of: self),
filename: ImageClassifierTests.multiObjectsImage.name,
type: ImageClassifierTests.multiObjectsImage.type))
let imageClassifierResult = try XCTUnwrap( let imageClassifierResult = try XCTUnwrap(
imageClassifier.classify( imageClassifier.classify(
@ -397,12 +402,14 @@ class ImageClassifierTests: XCTestCase {
let imageClassifierOptions = let imageClassifierOptions =
try XCTUnwrap( try XCTUnwrap(
imageClassifierOptionsWithModelPath(ImageClassifierTests.floatModelPath)) imageClassifierOptionsWithModelPath(
ImageClassifierTests.floatModelPath))
let maxResults = 3; let maxResults = 3;
imageClassifierOptions.maxResults = maxResults; imageClassifierOptions.maxResults = maxResults;
let imageClassifier = try XCTUnwrap(ImageClassifier(options: imageClassifierOptions)) let imageClassifier = try XCTUnwrap(ImageClassifier(
options: imageClassifierOptions))
let expectedCategories = [ let expectedCategories = [
ResultCategory( ResultCategory(
@ -440,12 +447,14 @@ class ImageClassifierTests: XCTestCase {
let imageClassifierOptions = let imageClassifierOptions =
try XCTUnwrap( try XCTUnwrap(
imageClassifierOptionsWithModelPath(ImageClassifierTests.floatModelPath)) imageClassifierOptionsWithModelPath(
ImageClassifierTests.floatModelPath))
let maxResults = 3; let maxResults = 3;
imageClassifierOptions.maxResults = maxResults; imageClassifierOptions.maxResults = maxResults;
let imageClassifier = try XCTUnwrap(ImageClassifier(options: imageClassifierOptions)) let imageClassifier = try XCTUnwrap(ImageClassifier(
options: imageClassifierOptions))
let expectedCategories = [ let expectedCategories = [
ResultCategory( ResultCategory(
@ -488,7 +497,9 @@ class ImageClassifierTests: XCTestCase {
imageClassifierOptionsWithModelPath( imageClassifierOptionsWithModelPath(
ImageClassifierTests.floatModelPath)) ImageClassifierTests.floatModelPath))
imageClassifierOptions.runningMode = runningMode imageClassifierOptions.runningMode = runningMode
imageClassifierOptions.completion = {(result: ImageClassifierResult?, error: Error?) -> () in imageClassifierOptions.completion = {(
result: ImageClassifierResult?,
error: Error?) -> () in
} }
assertCreateImageClassifierThrowsError( assertCreateImageClassifierThrowsError(
@ -500,7 +511,8 @@ class ImageClassifierTests: XCTestCase {
} }
} }
func testImageClassifierFailsWithMissingResultListenerInLiveStreamMode() throws { func testImageClassifierFailsWithMissingResultListenerInLiveStreamMode()
throws {
let imageClassifierOptions = let imageClassifierOptions =
try XCTUnwrap( try XCTUnwrap(
@ -527,10 +539,7 @@ class ImageClassifierTests: XCTestCase {
imageClassifierOptions)) imageClassifierOptions))
let mpImage = try XCTUnwrap( let mpImage = try XCTUnwrap(
MPImage.imageFromBundle( imageWithFileInfo(ImageClassifierTests.multiObjectsImage))
withClass: type(of: self),
filename: ImageClassifierTests.multiObjectsRotatedImage.name,
type: ImageClassifierTests.multiObjectsRotatedImage.type))
do { do {
try imageClassifier.classifyAsync( try imageClassifier.classifyAsync(
@ -554,7 +563,8 @@ class ImageClassifierTests: XCTestCase {
assertEqualErrorDescriptions( assertEqualErrorDescriptions(
error, error,
expectedLocalizedDescription: """ expectedLocalizedDescription: """
The vision task is not initialized with video mode. Current Running Mode: Image The vision task is not initialized with video mode. Current Running \
Mode: Image
""") """)
} }
} }
@ -572,10 +582,7 @@ class ImageClassifierTests: XCTestCase {
imageClassifierOptions)) imageClassifierOptions))
let mpImage = try XCTUnwrap( let mpImage = try XCTUnwrap(
MPImage.imageFromBundle( imageWithFileInfo(ImageClassifierTests.multiObjectsImage))
withClass: type(of: self),
filename: ImageClassifierTests.multiObjectsRotatedImage.name,
type: ImageClassifierTests.multiObjectsRotatedImage.type))
do { do {
try imageClassifier.classifyAsync( try imageClassifier.classifyAsync(
@ -620,10 +627,7 @@ class ImageClassifierTests: XCTestCase {
imageClassifierOptions)) imageClassifierOptions))
let mpImage = try XCTUnwrap( let mpImage = try XCTUnwrap(
MPImage.imageFromBundle( imageWithFileInfo(ImageClassifierTests.multiObjectsImage))
withClass: type(of: self),
filename: ImageClassifierTests.multiObjectsRotatedImage.name,
type: ImageClassifierTests.multiObjectsRotatedImage.type))
do { do {
let imagClassifierResult = try imageClassifier.classify( let imagClassifierResult = try imageClassifier.classify(
@ -668,10 +672,7 @@ class ImageClassifierTests: XCTestCase {
imageClassifierOptions)) imageClassifierOptions))
let mpImage = try XCTUnwrap( let mpImage = try XCTUnwrap(
MPImage.imageFromBundle( imageWithFileInfo(ImageClassifierTests.burgerImage))
withClass: type(of: self),
filename: ImageClassifierTests.burgerImage.name,
type: ImageClassifierTests.burgerImage.type))
for i in 0..<3 { for i in 0..<3 {
let imageClassifierResult = try XCTUnwrap( let imageClassifierResult = try XCTUnwrap(
@ -681,99 +682,9 @@ class ImageClassifierTests: XCTestCase {
try assertImageClassifierResult( try assertImageClassifierResult(
imageClassifierResult, imageClassifierResult,
hasCategoryCount: maxResults, hasCategoryCount: maxResults,
andCategories: ImageClassifierTests.expectedResultsClassifyBurgerImageWithFloatModel andCategories:
ImageClassifierTests.expectedResultsClassifyBurgerImageWithFloatModel
) )
} }
} }
func testClassifyWithLiveStreamModeSucceeds() throws {
let imageClassifierOptions =
try XCTUnwrap(
imageClassifierOptionsWithModelPath(
ImageClassifierTests.floatModelPath))
imageClassifierOptions.runningMode = .liveStream
let maxResults = 3
imageClassifierOptions.maxResults = maxResults
let expectation = expectation(description: "liveStreamClassify")
imageClassifierOptions.completion = {(result: ImageClassifierResult?, error: Error?) -> () in
do {
try self.assertImageClassifierResult(
try XCTUnwrap(result),
hasCategoryCount: maxResults,
andCategories: ImageClassifierTests.expectedResultsClassifyBurgerImageWithFloatModel)
}
catch {
}
expectation.fulfill()
}
let imageClassifier = try XCTUnwrap(ImageClassifier(options:
imageClassifierOptions))
let mpImage = try XCTUnwrap(
MPImage.imageFromBundle(
withClass: type(of: self),
filename: ImageClassifierTests.burgerImage.name,
type: ImageClassifierTests.burgerImage.type))
for i in 0..<3 {
XCTAssertNoThrow(
try imageClassifier.classifyAsync(
image: mpImage,
timestampMs: i))
}
wait(for: [expectation], timeout: 10)
}
// func testClassifyWithMaxResultsSucceeds() throws {
// let textClassifierOptions =
// try XCTUnwrap(
// textClassifierOptionsWithModelPath(TextClassifierTests.bertModelPath))
// textClassifierOptions.maxResults = 1
// let textClassifier =
// try XCTUnwrap(TextClassifier(options: textClassifierOptions))
// try assertResultsForClassify(
// text: TextClassifierTests.negativeText,
// using: textClassifier,
// equals: TextClassifierTests.bertNegativeTextResultsForEdgeTestCases)
// }
// func testClassifyWithCategoryAllowlistSucceeds() throws {
// let textClassifierOptions =
// try XCTUnwrap(
// textClassifierOptionsWithModelPath(TextClassifierTests.bertModelPath))
// textClassifierOptions.categoryAllowlist = ["negative"]
// let textClassifier =
// try XCTUnwrap(TextClassifier(options: textClassifierOptions))
// try assertResultsForClassify(
// text: TextClassifierTests.negativeText,
// using: textClassifier,
// equals: TextClassifierTests.bertNegativeTextResultsForEdgeTestCases)
// }
// func testClassifyWithCategoryDenylistSucceeds() throws {
// let textClassifierOptions =
// try XCTUnwrap(
// textClassifierOptionsWithModelPath(TextClassifierTests.bertModelPath))
// textClassifierOptions.categoryDenylist = ["positive"]
// let textClassifier =
// try XCTUnwrap(TextClassifier(options: textClassifierOptions))
// try assertResultsForClassify(
// text: TextClassifierTests.negativeText,
// using: textClassifier,
// equals: TextClassifierTests.bertNegativeTextResultsForEdgeTestCases)
// }
} }

View File

@ -50,6 +50,7 @@ objc_library(
deps = [ deps = [
":MPPImageClassifierOptions", ":MPPImageClassifierOptions",
":MPPImageClassifierResult", ":MPPImageClassifierResult",
"//mediapipe/calculators/core:flow_limiter_calculator",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils", "//mediapipe/tasks/ios/common/utils:MPPCommonUtils",