ClassificationAggregationCalculator should fill in the timestamp_ms
field of the classification results in the stream mode.
Per user feedback, the consistency between the packet timestamp and the timestamp field of the classification result helps reducing the confusion. PiperOrigin-RevId: 501657922
This commit is contained in:
parent
1683d572ed
commit
8156da3418
|
@ -143,8 +143,9 @@ void CheckStreamingModeResults(std::vector<AudioClassifierResult> outputs) {
|
|||
EXPECT_EQ(outputs.size(), 5);
|
||||
// Ignore last result, which operates on a too small chunk to return relevant
|
||||
// results.
|
||||
std::vector<int64> timestamps_ms = {0, 975, 1950, 2925};
|
||||
for (int i = 0; i < outputs.size() - 1; i++) {
|
||||
EXPECT_FALSE(outputs[i].timestamp_ms.has_value());
|
||||
EXPECT_EQ(outputs[i].timestamp_ms.value(), timestamps_ms[i]);
|
||||
EXPECT_EQ(outputs[i].classifications.size(), 1);
|
||||
EXPECT_EQ(outputs[i].classifications[0].head_index, 0);
|
||||
EXPECT_EQ(outputs[i].classifications[0].head_name, "scores");
|
||||
|
|
|
@ -188,6 +188,7 @@ ClassificationAggregationCalculator::ConvertToClassificationResult(
|
|||
*classifications->mutable_classification_list() =
|
||||
std::move(classification_lists[i]);
|
||||
}
|
||||
result.set_timestamp_ms(cc->InputTimestamp().Value() / 1000);
|
||||
cached_classifications_.erase(cc->InputTimestamp().Value());
|
||||
return result;
|
||||
}
|
||||
|
|
|
@ -150,14 +150,15 @@ class ClassificationAggregationCalculatorTest
|
|||
CalculatorGraph calculator_graph_;
|
||||
};
|
||||
|
||||
TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithoutTimestamps) {
|
||||
TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithoutAggregation) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph());
|
||||
MP_ASSERT_OK(Send({MakeClassificationList(0), MakeClassificationList(1)}));
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto result, GetResult<ClassificationResult>(poller));
|
||||
|
||||
EXPECT_THAT(result,
|
||||
EqualsProto(ParseTextProtoOrDie<ClassificationResult>(
|
||||
R"pb(classifications {
|
||||
R"pb(timestamp_ms: 0,
|
||||
classifications {
|
||||
head_index: 0
|
||||
head_name: "foo"
|
||||
classification_list { classification { index: 0 } }
|
||||
|
@ -169,7 +170,7 @@ TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithoutTimestamps) {
|
|||
})pb")));
|
||||
}
|
||||
|
||||
TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithTimestamps) {
|
||||
TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithAggregation) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(/*connect_timestamps=*/true));
|
||||
MP_ASSERT_OK(Send({MakeClassificationList(0), MakeClassificationList(1)}));
|
||||
MP_ASSERT_OK(Send(
|
||||
|
|
|
@ -534,6 +534,7 @@ TEST_F(PostprocessingTest, SucceedsWithoutMetadata) {
|
|||
// Validate results.
|
||||
EXPECT_THAT(results,
|
||||
EqualsProto(ParseTextProtoOrDie<ClassificationResult>(R"pb(
|
||||
timestamp_ms: 0,
|
||||
classifications {
|
||||
head_index: 0
|
||||
classification_list {
|
||||
|
@ -567,6 +568,7 @@ TEST_F(PostprocessingTest, SucceedsWithMetadata) {
|
|||
// Validate results.
|
||||
EXPECT_THAT(
|
||||
results, EqualsProto(ParseTextProtoOrDie<ClassificationResult>(R"pb(
|
||||
timestamp_ms: 0,
|
||||
classifications {
|
||||
head_index: 0
|
||||
head_name: "probability"
|
||||
|
@ -603,6 +605,7 @@ TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) {
|
|||
// Validate results.
|
||||
EXPECT_THAT(
|
||||
results, EqualsProto(ParseTextProtoOrDie<ClassificationResult>(R"pb(
|
||||
timestamp_ms: 0,
|
||||
classifications {
|
||||
head_index: 0
|
||||
head_name: "probability"
|
||||
|
@ -646,6 +649,7 @@ TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) {
|
|||
// Validate results.
|
||||
EXPECT_THAT(
|
||||
results, EqualsProto(ParseTextProtoOrDie<ClassificationResult>(R"pb(
|
||||
timestamp_ms: 0,
|
||||
classifications {
|
||||
head_index: 0
|
||||
head_name: "yamnet_classification"
|
||||
|
|
|
@ -61,7 +61,7 @@ def _generate_empty_results() -> ImageClassifierResult:
|
|||
timestamp_ms=0)
|
||||
|
||||
|
||||
def _generate_burger_results() -> ImageClassifierResult:
|
||||
def _generate_burger_results(timestamp_ms=0) -> ImageClassifierResult:
|
||||
return ImageClassifierResult(
|
||||
classifications=[
|
||||
_Classifications(
|
||||
|
@ -70,30 +70,36 @@ def _generate_burger_results() -> ImageClassifierResult:
|
|||
index=934,
|
||||
score=0.793959,
|
||||
display_name='',
|
||||
category_name='cheeseburger'),
|
||||
category_name='cheeseburger',
|
||||
),
|
||||
_Category(
|
||||
index=932,
|
||||
score=0.0273929,
|
||||
display_name='',
|
||||
category_name='bagel'),
|
||||
category_name='bagel',
|
||||
),
|
||||
_Category(
|
||||
index=925,
|
||||
score=0.0193408,
|
||||
display_name='',
|
||||
category_name='guacamole'),
|
||||
category_name='guacamole',
|
||||
),
|
||||
_Category(
|
||||
index=963,
|
||||
score=0.00632786,
|
||||
display_name='',
|
||||
category_name='meat loaf')
|
||||
category_name='meat loaf',
|
||||
),
|
||||
],
|
||||
head_index=0,
|
||||
head_name='probability')
|
||||
head_name='probability',
|
||||
)
|
||||
],
|
||||
timestamp_ms=0)
|
||||
timestamp_ms=timestamp_ms,
|
||||
)
|
||||
|
||||
|
||||
def _generate_soccer_ball_results() -> ImageClassifierResult:
|
||||
def _generate_soccer_ball_results(timestamp_ms=0) -> ImageClassifierResult:
|
||||
return ImageClassifierResult(
|
||||
classifications=[
|
||||
_Classifications(
|
||||
|
@ -102,12 +108,15 @@ def _generate_soccer_ball_results() -> ImageClassifierResult:
|
|||
index=806,
|
||||
score=0.996527,
|
||||
display_name='',
|
||||
category_name='soccer ball')
|
||||
category_name='soccer ball',
|
||||
)
|
||||
],
|
||||
head_index=0,
|
||||
head_name='probability')
|
||||
head_name='probability',
|
||||
)
|
||||
],
|
||||
timestamp_ms=0)
|
||||
timestamp_ms=timestamp_ms,
|
||||
)
|
||||
|
||||
|
||||
class ModelFileType(enum.Enum):
|
||||
|
@ -379,8 +388,11 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
for timestamp in range(0, 300, 30):
|
||||
classification_result = classifier.classify_for_video(
|
||||
self.test_image, timestamp)
|
||||
test_utils.assert_proto_equals(self, classification_result.to_pb2(),
|
||||
_generate_burger_results().to_pb2())
|
||||
test_utils.assert_proto_equals(
|
||||
self,
|
||||
classification_result.to_pb2(),
|
||||
_generate_burger_results(timestamp).to_pb2(),
|
||||
)
|
||||
|
||||
def test_classify_for_video_succeeds_with_region_of_interest(self):
|
||||
options = _ImageClassifierOptions(
|
||||
|
@ -398,8 +410,11 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
for timestamp in range(0, 300, 30):
|
||||
classification_result = classifier.classify_for_video(
|
||||
test_image, timestamp, image_processing_options)
|
||||
test_utils.assert_proto_equals(self, classification_result.to_pb2(),
|
||||
_generate_soccer_ball_results().to_pb2())
|
||||
test_utils.assert_proto_equals(
|
||||
self,
|
||||
classification_result.to_pb2(),
|
||||
_generate_soccer_ball_results(timestamp).to_pb2(),
|
||||
)
|
||||
|
||||
def test_calling_classify_in_live_stream_mode(self):
|
||||
options = _ImageClassifierOptions(
|
||||
|
@ -455,8 +470,7 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
score_threshold=threshold,
|
||||
result_callback=check_result)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
for timestamp in range(0, 300, 30):
|
||||
classifier.classify_async(self.test_image, timestamp)
|
||||
classifier.classify_async(self.test_image, 0)
|
||||
|
||||
def test_classify_async_succeeds_with_region_of_interest(self):
|
||||
# Load the test image.
|
||||
|
@ -470,8 +484,9 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
|
||||
def check_result(result: ImageClassifierResult, output_image: _Image,
|
||||
timestamp_ms: int):
|
||||
test_utils.assert_proto_equals(self, result.to_pb2(),
|
||||
_generate_soccer_ball_results().to_pb2())
|
||||
test_utils.assert_proto_equals(
|
||||
self, result.to_pb2(), _generate_soccer_ball_results(100).to_pb2()
|
||||
)
|
||||
self.assertEqual(output_image.width, test_image.width)
|
||||
self.assertEqual(output_image.height, test_image.height)
|
||||
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
||||
|
@ -483,9 +498,7 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
max_results=1,
|
||||
result_callback=check_result)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
for timestamp in range(0, 300, 30):
|
||||
classifier.classify_async(test_image, timestamp,
|
||||
image_processing_options)
|
||||
classifier.classify_async(test_image, 100, image_processing_options)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue
Block a user