ClassificationAggregationCalculator should fill in the timestamp_ms field of the classification results in the stream mode.

Per user feedback, the consistency between the packet timestamp and the timestamp field of the classification result helps reducing the confusion.

PiperOrigin-RevId: 501657922
This commit is contained in:
Jiuqiang Tang 2023-01-12 13:52:27 -08:00 committed by Copybara-Service
parent 1683d572ed
commit 8156da3418
5 changed files with 46 additions and 26 deletions

View File

@ -143,8 +143,9 @@ void CheckStreamingModeResults(std::vector<AudioClassifierResult> outputs) {
EXPECT_EQ(outputs.size(), 5);
// Ignore last result, which operates on a too small chunk to return relevant
// results.
std::vector<int64> timestamps_ms = {0, 975, 1950, 2925};
for (int i = 0; i < outputs.size() - 1; i++) {
EXPECT_FALSE(outputs[i].timestamp_ms.has_value());
EXPECT_EQ(outputs[i].timestamp_ms.value(), timestamps_ms[i]);
EXPECT_EQ(outputs[i].classifications.size(), 1);
EXPECT_EQ(outputs[i].classifications[0].head_index, 0);
EXPECT_EQ(outputs[i].classifications[0].head_name, "scores");

View File

@ -188,6 +188,7 @@ ClassificationAggregationCalculator::ConvertToClassificationResult(
*classifications->mutable_classification_list() =
std::move(classification_lists[i]);
}
result.set_timestamp_ms(cc->InputTimestamp().Value() / 1000);
cached_classifications_.erase(cc->InputTimestamp().Value());
return result;
}

View File

@ -150,14 +150,15 @@ class ClassificationAggregationCalculatorTest
CalculatorGraph calculator_graph_;
};
TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithoutTimestamps) {
TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithoutAggregation) {
MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph());
MP_ASSERT_OK(Send({MakeClassificationList(0), MakeClassificationList(1)}));
MP_ASSERT_OK_AND_ASSIGN(auto result, GetResult<ClassificationResult>(poller));
EXPECT_THAT(result,
EqualsProto(ParseTextProtoOrDie<ClassificationResult>(
R"pb(classifications {
R"pb(timestamp_ms: 0,
classifications {
head_index: 0
head_name: "foo"
classification_list { classification { index: 0 } }
@ -169,7 +170,7 @@ TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithoutTimestamps) {
})pb")));
}
TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithTimestamps) {
TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithAggregation) {
MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(/*connect_timestamps=*/true));
MP_ASSERT_OK(Send({MakeClassificationList(0), MakeClassificationList(1)}));
MP_ASSERT_OK(Send(

View File

@ -534,6 +534,7 @@ TEST_F(PostprocessingTest, SucceedsWithoutMetadata) {
// Validate results.
EXPECT_THAT(results,
EqualsProto(ParseTextProtoOrDie<ClassificationResult>(R"pb(
timestamp_ms: 0,
classifications {
head_index: 0
classification_list {
@ -567,6 +568,7 @@ TEST_F(PostprocessingTest, SucceedsWithMetadata) {
// Validate results.
EXPECT_THAT(
results, EqualsProto(ParseTextProtoOrDie<ClassificationResult>(R"pb(
timestamp_ms: 0,
classifications {
head_index: 0
head_name: "probability"
@ -603,6 +605,7 @@ TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) {
// Validate results.
EXPECT_THAT(
results, EqualsProto(ParseTextProtoOrDie<ClassificationResult>(R"pb(
timestamp_ms: 0,
classifications {
head_index: 0
head_name: "probability"
@ -646,6 +649,7 @@ TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) {
// Validate results.
EXPECT_THAT(
results, EqualsProto(ParseTextProtoOrDie<ClassificationResult>(R"pb(
timestamp_ms: 0,
classifications {
head_index: 0
head_name: "yamnet_classification"

View File

@ -61,7 +61,7 @@ def _generate_empty_results() -> ImageClassifierResult:
timestamp_ms=0)
def _generate_burger_results() -> ImageClassifierResult:
def _generate_burger_results(timestamp_ms=0) -> ImageClassifierResult:
return ImageClassifierResult(
classifications=[
_Classifications(
@ -70,30 +70,36 @@ def _generate_burger_results() -> ImageClassifierResult:
index=934,
score=0.793959,
display_name='',
category_name='cheeseburger'),
category_name='cheeseburger',
),
_Category(
index=932,
score=0.0273929,
display_name='',
category_name='bagel'),
category_name='bagel',
),
_Category(
index=925,
score=0.0193408,
display_name='',
category_name='guacamole'),
category_name='guacamole',
),
_Category(
index=963,
score=0.00632786,
display_name='',
category_name='meat loaf')
category_name='meat loaf',
),
],
head_index=0,
head_name='probability')
head_name='probability',
)
],
timestamp_ms=0)
timestamp_ms=timestamp_ms,
)
def _generate_soccer_ball_results() -> ImageClassifierResult:
def _generate_soccer_ball_results(timestamp_ms=0) -> ImageClassifierResult:
return ImageClassifierResult(
classifications=[
_Classifications(
@ -102,12 +108,15 @@ def _generate_soccer_ball_results() -> ImageClassifierResult:
index=806,
score=0.996527,
display_name='',
category_name='soccer ball')
category_name='soccer ball',
)
],
head_index=0,
head_name='probability')
head_name='probability',
)
],
timestamp_ms=0)
timestamp_ms=timestamp_ms,
)
class ModelFileType(enum.Enum):
@ -379,8 +388,11 @@ class ImageClassifierTest(parameterized.TestCase):
for timestamp in range(0, 300, 30):
classification_result = classifier.classify_for_video(
self.test_image, timestamp)
test_utils.assert_proto_equals(self, classification_result.to_pb2(),
_generate_burger_results().to_pb2())
test_utils.assert_proto_equals(
self,
classification_result.to_pb2(),
_generate_burger_results(timestamp).to_pb2(),
)
def test_classify_for_video_succeeds_with_region_of_interest(self):
options = _ImageClassifierOptions(
@ -398,8 +410,11 @@ class ImageClassifierTest(parameterized.TestCase):
for timestamp in range(0, 300, 30):
classification_result = classifier.classify_for_video(
test_image, timestamp, image_processing_options)
test_utils.assert_proto_equals(self, classification_result.to_pb2(),
_generate_soccer_ball_results().to_pb2())
test_utils.assert_proto_equals(
self,
classification_result.to_pb2(),
_generate_soccer_ball_results(timestamp).to_pb2(),
)
def test_calling_classify_in_live_stream_mode(self):
options = _ImageClassifierOptions(
@ -455,8 +470,7 @@ class ImageClassifierTest(parameterized.TestCase):
score_threshold=threshold,
result_callback=check_result)
with _ImageClassifier.create_from_options(options) as classifier:
for timestamp in range(0, 300, 30):
classifier.classify_async(self.test_image, timestamp)
classifier.classify_async(self.test_image, 0)
def test_classify_async_succeeds_with_region_of_interest(self):
# Load the test image.
@ -470,8 +484,9 @@ class ImageClassifierTest(parameterized.TestCase):
def check_result(result: ImageClassifierResult, output_image: _Image,
timestamp_ms: int):
test_utils.assert_proto_equals(self, result.to_pb2(),
_generate_soccer_ball_results().to_pb2())
test_utils.assert_proto_equals(
self, result.to_pb2(), _generate_soccer_ball_results(100).to_pb2()
)
self.assertEqual(output_image.width, test_image.width)
self.assertEqual(output_image.height, test_image.height)
self.assertLess(observed_timestamp_ms, timestamp_ms)
@ -483,9 +498,7 @@ class ImageClassifierTest(parameterized.TestCase):
max_results=1,
result_callback=check_result)
with _ImageClassifier.create_from_options(options) as classifier:
for timestamp in range(0, 300, 30):
classifier.classify_async(test_image, timestamp,
image_processing_options)
classifier.classify_async(test_image, 100, image_processing_options)
if __name__ == '__main__':