Add keys for the context that better match the feature list keys for text.

PiperOrigin-RevId: 544430289
MediaPipe Team 2023-06-29 12:27:49 -07:00 committed by Copybara-Service
parent 0bb4ee8941
commit 52cea59d41
5 changed files with 66 additions and 0 deletions

mediapipe/util/sequence/README.md

@@ -593,6 +593,8 @@ ground truth transcripts.
|-----|------|------------------------|-------------|
|`text/language`|context bytes|`set_text_language` / `SetTextLanguage`|The language for the corresponding text.|
|`text/context/content`|context bytes|`set_text_context_content` / `SetTextContextContent`|Storage for large blocks of text in the context.|
|`text/context/token_id`|context int list|`set_text_context_token_id` / `SetTextContextTokenId`|Storage for large blocks of text in the context as token ids.|
|`text/context/embedding`|context float list|`set_text_context_embedding` / `SetTextContextEmbedding`|Storage for large blocks of text in the context as embeddings.|
|`text/content`|feature list bytes|`add_text_content` / `AddTextContent`|One (or a few) text tokens that occur at one timestamp.|
|`text/timestamp`|feature list int|`add_text_timestamp` / `AddTextTimestamp`|When a text token occurs, in microseconds.|
|`text/duration`|feature list int|`add_text_duration` / `AddTextDuration`|The duration in microseconds for the corresponding text tokens.|
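
For orientation, here is a minimal sketch of filling and reading the two new context keys through the Python accessors named in the table. The import path and the illustrative token-id/embedding values are assumptions, not part of this change; the accessor names come from the table above and the tests in this commit.

```python
import tensorflow as tf
from mediapipe.util.sequence import media_sequence as ms  # assumed import path

example = tf.train.SequenceExample()

# Whole-document text lives once in the context; the new keys hold its
# tokenized and embedded forms alongside the raw bytes.
ms.set_text_context_content(b"the quick brown fox", example)
ms.set_text_context_token_id([47, 35, 12], example)          # illustrative token ids
ms.set_text_context_embedding([0.12, -0.48, 0.33], example)  # illustrative embedding

assert ms.has_text_context_token_id(example)
print(list(ms.get_text_context_token_id(example)))   # [47, 35, 12]
print(list(ms.get_text_context_embedding(example)))  # approximately [0.12, -0.48, 0.33]
```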

mediapipe/util/sequence/media_sequence.h

@@ -634,6 +634,10 @@ PREFIXED_IMAGE(InstanceSegmentation, kInstanceSegmentationPrefix);
const char kTextLanguageKey[] = "text/language";
// A large block of text that applies to the media.
const char kTextContextContentKey[] = "text/context/content";
// A large block of text that applies to the media as token ids.
const char kTextContextTokenIdKey[] = "text/context/token_id";
// A large block of text that applies to the media as embeddings.
const char kTextContextEmbeddingKey[] = "text/context/embedding";
// Feature list keys:
// The text contents for a given time.

@@ -651,6 +655,8 @@ const char kTextTokenIdKey[] = "text/token/id";
BYTES_CONTEXT_FEATURE(TextLanguage, kTextLanguageKey);
BYTES_CONTEXT_FEATURE(TextContextContent, kTextContextContentKey);
VECTOR_INT64_CONTEXT_FEATURE(TextContextTokenId, kTextContextTokenIdKey);
VECTOR_FLOAT_CONTEXT_FEATURE(TextContextEmbedding, kTextContextEmbeddingKey);
BYTES_FEATURE_LIST(TextContent, kTextContentKey);
INT64_FEATURE_LIST(TextTimestamp, kTextTimestampKey);
INT64_FEATURE_LIST(TextDuration, kTextDurationKey);

mediapipe/util/sequence/media_sequence.py

@@ -601,6 +601,10 @@ _create_image_with_prefix("instance_segmentation", INSTANCE_SEGMENTATION_PREFIX)
TEXT_LANGUAGE_KEY = "text/language"
# A large block of text that applies to the media.
TEXT_CONTEXT_CONTENT_KEY = "text/context/content"
# A large block of text that applies to the media as token ids.
TEXT_CONTEXT_TOKEN_ID_KEY = "text/context/token_id"
# A large block of text that applies to the media as embeddings.
TEXT_CONTEXT_EMBEDDING_KEY = "text/context/embedding"
# The text contents for a given time.
TEXT_CONTENT_KEY = "text/content"

@@ -619,6 +623,10 @@ msu.create_bytes_context_feature(
    "text_language", TEXT_LANGUAGE_KEY, module_dict=globals())
msu.create_bytes_context_feature(
    "text_context_content", TEXT_CONTEXT_CONTENT_KEY, module_dict=globals())
msu.create_int_list_context_feature(
    "text_context_token_id", TEXT_CONTEXT_TOKEN_ID_KEY, module_dict=globals())
msu.create_float_list_context_feature(
    "text_context_embedding", TEXT_CONTEXT_EMBEDDING_KEY, module_dict=globals())
msu.create_bytes_feature_list(
    "text_content", TEXT_CONTENT_KEY, module_dict=globals())
msu.create_int_feature_list(
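
As a rough sketch of what the generated setters store, the new keys map to plain context features on the `tf.train.SequenceExample`: an `int64_list` under `text/context/token_id` and a `float_list` under `text/context/embedding`. The snippet below reads the values back both through the generated getters (names taken from the tests in this commit) and through the raw proto fields; the import path is an assumption.

```python
import tensorflow as tf
from mediapipe.util.sequence import media_sequence as ms  # assumed import path

example = tf.train.SequenceExample()
ms.set_text_context_token_id([47, 49], example)
ms.set_text_context_embedding([0.47, 0.49], example)

# The generated getters and the underlying proto fields return the same values.
print(list(ms.get_text_context_token_id(example)))                               # [47, 49]
print(list(example.context.feature["text/context/token_id"].int64_list.value))   # [47, 49]
print(list(example.context.feature["text/context/embedding"].float_list.value))  # ~[0.47, 0.49] (float32)
```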

mediapipe/util/sequence/media_sequence_test.cc

@@ -16,6 +16,7 @@
#include <algorithm>
#include <string>
#include <vector>

#include "mediapipe/framework/formats/location.h"
#include "mediapipe/framework/port/gmock.h"

@@ -711,6 +712,30 @@ TEST(MediaSequenceTest, RoundTripTextContextContent) {
  ASSERT_FALSE(HasTextContextContent(sequence));
}

TEST(MediaSequenceTest, RoundTripTextContextTokenId) {
  tensorflow::SequenceExample sequence;
  ASSERT_FALSE(HasTextContextTokenId(sequence));
  std::vector<int64_t> vi = {47, 35};
  SetTextContextTokenId(vi, &sequence);
  ASSERT_TRUE(HasTextContextTokenId(sequence));
  ASSERT_EQ(GetTextContextTokenId(sequence).size(), vi.size());
  ASSERT_EQ(GetTextContextTokenId(sequence)[1], vi[1]);
  ClearTextContextTokenId(&sequence);
  ASSERT_FALSE(HasTextContextTokenId(sequence));
}

TEST(MediaSequenceTest, RoundTripTextContextEmbedding) {
  tensorflow::SequenceExample sequence;
  ASSERT_FALSE(HasTextContextEmbedding(sequence));
  std::vector<float> vi = {47., 35.};
  SetTextContextEmbedding(vi, &sequence);
  ASSERT_TRUE(HasTextContextEmbedding(sequence));
  ASSERT_EQ(GetTextContextEmbedding(sequence).size(), vi.size());
  ASSERT_EQ(GetTextContextEmbedding(sequence)[1], vi[1]);
  ClearTextContextEmbedding(&sequence);
  ASSERT_FALSE(HasTextContextEmbedding(sequence));
}

TEST(MediaSequenceTest, RoundTripTextContent) {
  tensorflow::SequenceExample sequence;
  std::vector<std::string> text = {"test", "again"};

mediapipe/util/sequence/media_sequence_test.py

@@ -129,6 +129,8 @@ class MediaSequenceTest(tf.test.TestCase):
    ms.add_bbox_embedding_confidence((0.47, 0.49), example)
    ms.set_text_language(b"test", example)
    ms.set_text_context_content(b"text", example)
    ms.set_text_context_token_id([47, 49], example)
    ms.set_text_context_embedding([0.47, 0.49], example)
    ms.add_text_content(b"one", example)
    ms.add_text_timestamp(47, example)
    ms.add_text_confidence(0.47, example)

@@ -260,6 +262,29 @@ class MediaSequenceTest(tf.test.TestCase):
    self.assertFalse(ms.has_feature_dimensions(example, "1"))
    self.assertFalse(ms.has_feature_dimensions(example, "2"))

  def test_text_context_round_trip(self):
    example = tf.train.SequenceExample()
    text_content = b"text content"
    text_token_ids = np.array([1, 2, 3, 4])
    text_embeddings = np.array([0.1, 0.2, 0.3, 0.4])
    self.assertFalse(ms.has_text_context_embedding(example))
    self.assertFalse(ms.has_text_context_token_id(example))
    self.assertFalse(ms.has_text_context_content(example))
    ms.set_text_context_content(text_content, example)
    ms.set_text_context_token_id(text_token_ids, example)
    ms.set_text_context_embedding(text_embeddings, example)
    self.assertEqual(text_content, ms.get_text_context_content(example))
    self.assertAllClose(text_token_ids, ms.get_text_context_token_id(example))
    self.assertAllClose(text_embeddings, ms.get_text_context_embedding(example))
    self.assertTrue(ms.has_text_context_embedding(example))
    self.assertTrue(ms.has_text_context_token_id(example))
    self.assertTrue(ms.has_text_context_content(example))
    ms.clear_text_context_content(example)
    ms.clear_text_context_token_id(example)
    ms.clear_text_context_embedding(example)
    self.assertFalse(ms.has_text_context_embedding(example))
    self.assertFalse(ms.has_text_context_token_id(example))
    self.assertFalse(ms.has_text_context_content(example))


if __name__ == "__main__":
  tf.test.main()