Add context keys for text that better match the text feature list.
PiperOrigin-RevId: 544430289
parent 0bb4ee8941
commit 52cea59d41

mediapipe/util/sequence/README.md
@@ -593,6 +593,8 @@ ground truth transcripts.

|key|type|python call / c++ call|description|
|-----|------|------------------------|-------------|
|`text/language`|context bytes|`set_text_language` / `SetTextLanguage`|The language for the corresponding text.|
|`text/context/content`|context bytes|`set_text_context_content` / `SetTextContextContent`|Storage for large blocks of text in the context.|
|`text/context/token_id`|context int list|`set_text_context_token_id` / `SetTextContextTokenId`|Storage for large blocks of text in the context as token ids.|
|`text/context/embedding`|context float list|`set_text_context_embedding` / `SetTextContextEmbedding`|Storage for large blocks of text in the context as embeddings.|
|`text/content`|feature list bytes|`add_text_content` / `AddTextContent`|One (or a few) text tokens that occur at one timestamp.|
|`text/timestamp`|feature list int|`add_text_timestamp` / `AddTextTimestamp`|When a text token occurs, in microseconds.|
|`text/duration`|feature list int|`add_text_duration` / `AddTextDuration`|The duration in microseconds for the corresponding text tokens.|
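
Context keys store one value per media item, while feature list keys store one value per timestep. Here is a minimal, illustrative sketch of writing the new keys with the Python calls from the table above (the values are placeholders; the import assumes the standard media_sequence module path):

```python
import tensorflow as tf
from mediapipe.util.sequence import media_sequence as ms

example = tf.train.SequenceExample()

# Context: set once for the whole media item.
ms.set_text_language(b"en", example)
ms.set_text_context_content(b"a large block of text", example)
ms.set_text_context_token_id([47, 49], example)
ms.set_text_context_embedding([0.47, 0.49], example)

# Feature list: add one entry per timestamped token.
ms.add_text_content(b"one", example)
ms.add_text_timestamp(47, example)  # microseconds
ms.add_text_duration(10, example)   # microseconds
```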

mediapipe/util/sequence/media_sequence.h
@@ -634,6 +634,10 @@ PREFIXED_IMAGE(InstanceSegmentation, kInstanceSegmentationPrefix);

const char kTextLanguageKey[] = "text/language";
// A large block of text that applies to the media.
const char kTextContextContentKey[] = "text/context/content";
// A large block of text that applies to the media as token ids.
const char kTextContextTokenIdKey[] = "text/context/token_id";
// A large block of text that applies to the media as embeddings.
const char kTextContextEmbeddingKey[] = "text/context/embedding";

// Feature list keys:
// The text contents for a given time.
@@ -651,6 +655,8 @@ const char kTextTokenIdKey[] = "text/token/id";

BYTES_CONTEXT_FEATURE(TextLanguage, kTextLanguageKey);
BYTES_CONTEXT_FEATURE(TextContextContent, kTextContextContentKey);
VECTOR_INT64_CONTEXT_FEATURE(TextContextTokenId, kTextContextTokenIdKey);
VECTOR_FLOAT_CONTEXT_FEATURE(TextContextEmbedding, kTextContextEmbeddingKey);
BYTES_FEATURE_LIST(TextContent, kTextContentKey);
INT64_FEATURE_LIST(TextTimestamp, kTextTimestampKey);
INT64_FEATURE_LIST(TextDuration, kTextDurationKey);
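
Each macro expands into the full accessor set for its key: the context macros generate Has*/Get*/Set*/Clear* functions (e.g. SetTextContextTokenId and GetTextContextTokenId), while the feature list macros generate the Add* variants (e.g. AddTextContent). These are the functions the round-trip tests below exercise.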

mediapipe/util/sequence/media_sequence.py
@@ -601,6 +601,10 @@ _create_image_with_prefix("instance_segmentation", INSTANCE_SEGMENTATION_PREFIX)

TEXT_LANGUAGE_KEY = "text/language"
# A large block of text that applies to the media.
TEXT_CONTEXT_CONTENT_KEY = "text/context/content"
# A large block of text that applies to the media as token ids.
TEXT_CONTEXT_TOKEN_ID_KEY = "text/context/token_id"
# A large block of text that applies to the media as embeddings.
TEXT_CONTEXT_EMBEDDING_KEY = "text/context/embedding"

# The text contents for a given time.
TEXT_CONTENT_KEY = "text/content"
@@ -619,6 +623,10 @@ msu.create_bytes_context_feature(
    "text_language", TEXT_LANGUAGE_KEY, module_dict=globals())
msu.create_bytes_context_feature(
    "text_context_content", TEXT_CONTEXT_CONTENT_KEY, module_dict=globals())
msu.create_int_list_context_feature(
    "text_context_token_id", TEXT_CONTEXT_TOKEN_ID_KEY, module_dict=globals())
msu.create_float_list_context_feature(
    "text_context_embedding", TEXT_CONTEXT_EMBEDDING_KEY, module_dict=globals())
msu.create_bytes_feature_list(
    "text_content", TEXT_CONTENT_KEY, module_dict=globals())
msu.create_int_feature_list(
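
Each msu.create_* call above injects set_/get_/has_/clear_ accessors for its key into the module namespace via module_dict=globals(). A quick round trip with the generated functions, mirroring the tests below:

```python
import tensorflow as tf
from mediapipe.util.sequence import media_sequence as ms

example = tf.train.SequenceExample()
ms.set_text_context_token_id([1, 2, 3], example)
assert ms.has_text_context_token_id(example)
assert list(ms.get_text_context_token_id(example)) == [1, 2, 3]
ms.clear_text_context_token_id(example)
assert not ms.has_text_context_token_id(example)
```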

mediapipe/util/sequence/media_sequence_test.cc
@@ -16,6 +16,7 @@

#include <algorithm>
#include <string>
#include <vector>

#include "mediapipe/framework/formats/location.h"
#include "mediapipe/framework/port/gmock.h"
@@ -711,6 +712,30 @@ TEST(MediaSequenceTest, RoundTripTextContextContent) {
  ASSERT_FALSE(HasTextContextContent(sequence));
}

TEST(MediaSequenceTest, RoundTripTextContextTokenId) {
  tensorflow::SequenceExample sequence;
  ASSERT_FALSE(HasTextContextTokenId(sequence));
  std::vector<int64_t> vi = {47, 35};
  SetTextContextTokenId(vi, &sequence);
  ASSERT_TRUE(HasTextContextTokenId(sequence));
  ASSERT_EQ(GetTextContextTokenId(sequence).size(), vi.size());
  ASSERT_EQ(GetTextContextTokenId(sequence)[1], vi[1]);
  ClearTextContextTokenId(&sequence);
  ASSERT_FALSE(HasTextContextTokenId(sequence));
}

TEST(MediaSequenceTest, RoundTripTextContextEmbedding) {
  tensorflow::SequenceExample sequence;
  ASSERT_FALSE(HasTextContextEmbedding(sequence));
  std::vector<float> vi = {47., 35.};
  SetTextContextEmbedding(vi, &sequence);
  ASSERT_TRUE(HasTextContextEmbedding(sequence));
  ASSERT_EQ(GetTextContextEmbedding(sequence).size(), vi.size());
  ASSERT_EQ(GetTextContextEmbedding(sequence)[1], vi[1]);
  ClearTextContextEmbedding(&sequence);
  ASSERT_FALSE(HasTextContextEmbedding(sequence));
}

TEST(MediaSequenceTest, RoundTripTextContent) {
  tensorflow::SequenceExample sequence;
  std::vector<std::string> text = {"test", "again"};

mediapipe/util/sequence/media_sequence_test.py
@@ -129,6 +129,8 @@ class MediaSequenceTest(tf.test.TestCase):
    ms.add_bbox_embedding_confidence((0.47, 0.49), example)
    ms.set_text_language(b"test", example)
    ms.set_text_context_content(b"text", example)
    ms.set_text_context_token_id([47, 49], example)
    ms.set_text_context_embedding([0.47, 0.49], example)
    ms.add_text_content(b"one", example)
    ms.add_text_timestamp(47, example)
    ms.add_text_confidence(0.47, example)
@@ -260,6 +262,29 @@ class MediaSequenceTest(tf.test.TestCase):
    self.assertFalse(ms.has_feature_dimensions(example, "1"))
    self.assertFalse(ms.has_feature_dimensions(example, "2"))

  def test_text_context_round_trip(self):
    example = tf.train.SequenceExample()
    text_content = b"text content"
    text_token_ids = np.array([1, 2, 3, 4])
    text_embeddings = np.array([0.1, 0.2, 0.3, 0.4])
    self.assertFalse(ms.has_text_context_embedding(example))
    self.assertFalse(ms.has_text_context_token_id(example))
    self.assertFalse(ms.has_text_context_content(example))
    ms.set_text_context_content(text_content, example)
    ms.set_text_context_token_id(text_token_ids, example)
    ms.set_text_context_embedding(text_embeddings, example)
    self.assertEqual(text_content, ms.get_text_context_content(example))
    self.assertAllClose(text_token_ids, ms.get_text_context_token_id(example))
    self.assertAllClose(text_embeddings, ms.get_text_context_embedding(example))
    self.assertTrue(ms.has_text_context_embedding(example))
    self.assertTrue(ms.has_text_context_token_id(example))
    self.assertTrue(ms.has_text_context_content(example))
    ms.clear_text_context_content(example)
    ms.clear_text_context_token_id(example)
    ms.clear_text_context_embedding(example)
    self.assertFalse(ms.has_text_context_embedding(example))
    self.assertFalse(ms.has_text_context_token_id(example))
    self.assertFalse(ms.has_text_context_content(example))


if __name__ == "__main__":
  tf.test.main()