Add keys for the context that better match the feature list keys for text.
PiperOrigin-RevId: 544430289
commit 52cea59d41 (parent 0bb4ee8941)
@@ -593,6 +593,8 @@ ground truth transcripts.
 |-----|------|------------------------|-------------|
 |`text/language`|context bytes|`set_text_language` / `SetTextLanguage`|The language for the corresponding text.|
 |`text/context/content`|context bytes|`set_text_context_content` / `SetTextContextContent`|Storage for large blocks of text in the context.|
+|`text/context/token_id`|context int list|`set_text_context_token_id` / `SetTextContextTokenId`|Storage for large blocks of text in the context as token ids.|
+|`text/context/embedding`|context float list|`set_text_context_embedding` / `SetTextContextEmbedding`|Storage for large blocks of text in the context as embeddings.|
 |`text/content`|feature list bytes|`add_text_content` / `AddTextContent`|One (or a few) text tokens that occur at one timestamp.|
 |`text/timestamp`|feature list int|`add_text_timestamp` / `AddTextTimestamp`|When a text token occurs, in microseconds.|
 |`text/duration`|feature list int|`add_text_duration` / `AddTextDuration`|The duration in microseconds for the corresponding text tokens.|
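To illustrate how the new context keys fit alongside the existing ones, here is a minimal Python sketch (not part of this change) that fills both the context and the per-timestamp feature lists; it assumes the `mediapipe.util.sequence.media_sequence` module path used in the MediaPipe documentation:

```python
# Minimal sketch, assuming the media_sequence Python API shown in the table.
import tensorflow as tf
from mediapipe.util.sequence import media_sequence as ms

example = tf.train.SequenceExample()

# Context features describe the whole media item.
ms.set_text_language(b"en", example)
ms.set_text_context_content(b"full transcript of the clip", example)
ms.set_text_context_token_id([101, 2023, 102], example)      # token ids
ms.set_text_context_embedding([0.12, -0.03, 0.88], example)  # embedding

# Feature lists hold tokens tied to individual timestamps.
ms.add_text_content(b"hello", example)
ms.add_text_timestamp(47, example)     # microseconds
ms.add_text_duration(100000, example)  # microseconds
```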
@@ -634,6 +634,10 @@ PREFIXED_IMAGE(InstanceSegmentation, kInstanceSegmentationPrefix);
 const char kTextLanguageKey[] = "text/language";
 // A large block of text that applies to the media.
 const char kTextContextContentKey[] = "text/context/content";
+// A large block of text that applies to the media as token ids.
+const char kTextContextTokenIdKey[] = "text/context/token_id";
+// A large block of text that applies to the media as embeddings.
+const char kTextContextEmbeddingKey[] = "text/context/embedding";

 // Feature list keys:
 // The text contents for a given time.
@@ -651,6 +655,8 @@ const char kTextTokenIdKey[] = "text/token/id";

 BYTES_CONTEXT_FEATURE(TextLanguage, kTextLanguageKey);
 BYTES_CONTEXT_FEATURE(TextContextContent, kTextContextContentKey);
+VECTOR_INT64_CONTEXT_FEATURE(TextContextTokenId, kTextContextTokenIdKey);
+VECTOR_FLOAT_CONTEXT_FEATURE(TextContextEmbedding, kTextContextEmbeddingKey);
 BYTES_FEATURE_LIST(TextContent, kTextContentKey);
 INT64_FEATURE_LIST(TextTimestamp, kTextTimestampKey);
 INT64_FEATURE_LIST(TextDuration, kTextDurationKey);
@@ -601,6 +601,10 @@ _create_image_with_prefix("instance_segmentation", INSTANCE_SEGMENTATION_PREFIX)
 TEXT_LANGUAGE_KEY = "text/language"
 # A large block of text that applies to the media.
 TEXT_CONTEXT_CONTENT_KEY = "text/context/content"
+# A large block of text that applies to the media as token ids.
+TEXT_CONTEXT_TOKEN_ID_KEY = "text/context/token_id"
+# A large block of text that applies to the media as embeddings.
+TEXT_CONTEXT_EMBEDDING_KEY = "text/context/embedding"

 # The text contents for a given time.
 TEXT_CONTENT_KEY = "text/content"
@@ -619,6 +623,10 @@ msu.create_bytes_context_feature(
     "text_language", TEXT_LANGUAGE_KEY, module_dict=globals())
 msu.create_bytes_context_feature(
     "text_context_content", TEXT_CONTEXT_CONTENT_KEY, module_dict=globals())
+msu.create_int_list_context_feature(
+    "text_context_token_id", TEXT_CONTEXT_TOKEN_ID_KEY, module_dict=globals())
+msu.create_float_list_context_feature(
+    "text_context_embedding", TEXT_CONTEXT_EMBEDDING_KEY, module_dict=globals())
 msu.create_bytes_feature_list(
     "text_content", TEXT_CONTENT_KEY, module_dict=globals())
 msu.create_int_feature_list(
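Each `create_*_context_feature` call above generates module-level `set_`/`get_`/`has_`/`clear_` helpers for its key (the test changes below use them). As a rough sketch, not part of the diff, the int-list and float-list setters boil down to the following raw `tf.train.SequenceExample` proto writes:

```python
# Sketch of the raw proto storage behind the generated setters.
import tensorflow as tf

example = tf.train.SequenceExample()

# Roughly what set_text_context_token_id([47, 49], example) stores:
example.context.feature["text/context/token_id"].int64_list.value.extend(
    [47, 49])

# And the float-list counterpart for set_text_context_embedding:
example.context.feature["text/context/embedding"].float_list.value.extend(
    [0.47, 0.49])

print(example.context.feature["text/context/token_id"].int64_list.value)
```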
@@ -16,6 +16,7 @@

 #include <algorithm>
 #include <string>
+#include <vector>

 #include "mediapipe/framework/formats/location.h"
 #include "mediapipe/framework/port/gmock.h"
@@ -711,6 +712,30 @@ TEST(MediaSequenceTest, RoundTripTextContextContent) {
   ASSERT_FALSE(HasTextContextContent(sequence));
 }

+TEST(MediaSequenceTest, RoundTripTextContextTokenId) {
+  tensorflow::SequenceExample sequence;
+  ASSERT_FALSE(HasTextContextTokenId(sequence));
+  std::vector<int64_t> vi = {47, 35};
+  SetTextContextTokenId(vi, &sequence);
+  ASSERT_TRUE(HasTextContextTokenId(sequence));
+  ASSERT_EQ(GetTextContextTokenId(sequence).size(), vi.size());
+  ASSERT_EQ(GetTextContextTokenId(sequence)[1], vi[1]);
+  ClearTextContextTokenId(&sequence);
+  ASSERT_FALSE(HasTextContextTokenId(sequence));
+}
+
+TEST(MediaSequenceTest, RoundTripTextContextEmbedding) {
+  tensorflow::SequenceExample sequence;
+  ASSERT_FALSE(HasTextContextEmbedding(sequence));
+  std::vector<float> vi = {47., 35.};
+  SetTextContextEmbedding(vi, &sequence);
+  ASSERT_TRUE(HasTextContextEmbedding(sequence));
+  ASSERT_EQ(GetTextContextEmbedding(sequence).size(), vi.size());
+  ASSERT_EQ(GetTextContextEmbedding(sequence)[1], vi[1]);
+  ClearTextContextEmbedding(&sequence);
+  ASSERT_FALSE(HasTextContextEmbedding(sequence));
+}
+
 TEST(MediaSequenceTest, RoundTripTextContent) {
   tensorflow::SequenceExample sequence;
   std::vector<std::string> text = {"test", "again"};
@@ -129,6 +129,8 @@ class MediaSequenceTest(tf.test.TestCase):
     ms.add_bbox_embedding_confidence((0.47, 0.49), example)
     ms.set_text_language(b"test", example)
     ms.set_text_context_content(b"text", example)
+    ms.set_text_context_token_id([47, 49], example)
+    ms.set_text_context_embedding([0.47, 0.49], example)
     ms.add_text_content(b"one", example)
     ms.add_text_timestamp(47, example)
     ms.add_text_confidence(0.47, example)
@@ -260,6 +262,29 @@ class MediaSequenceTest(tf.test.TestCase):
     self.assertFalse(ms.has_feature_dimensions(example, "1"))
     self.assertFalse(ms.has_feature_dimensions(example, "2"))

+  def test_text_context_round_trip(self):
+    example = tf.train.SequenceExample()
+    text_content = b"text content"
+    text_token_ids = np.array([1, 2, 3, 4])
+    text_embeddings = np.array([0.1, 0.2, 0.3, 0.4])
+    self.assertFalse(ms.has_text_context_embedding(example))
+    self.assertFalse(ms.has_text_context_token_id(example))
+    self.assertFalse(ms.has_text_context_content(example))
+    ms.set_text_context_content(text_content, example)
+    ms.set_text_context_token_id(text_token_ids, example)
+    ms.set_text_context_embedding(text_embeddings, example)
+    self.assertEqual(text_content, ms.get_text_context_content(example))
+    self.assertAllClose(text_token_ids, ms.get_text_context_token_id(example))
+    self.assertAllClose(text_embeddings, ms.get_text_context_embedding(example))
+    self.assertTrue(ms.has_text_context_embedding(example))
+    self.assertTrue(ms.has_text_context_token_id(example))
+    self.assertTrue(ms.has_text_context_content(example))
+    ms.clear_text_context_content(example)
+    ms.clear_text_context_token_id(example)
+    ms.clear_text_context_embedding(example)
+    self.assertFalse(ms.has_text_context_embedding(example))
+    self.assertFalse(ms.has_text_context_token_id(example))
+    self.assertFalse(ms.has_text_context_content(example))

 if __name__ == "__main__":
   tf.test.main()
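Beyond the round-trip tests, a common next step is reading these context keys back when decoding serialized SequenceExamples. A hedged sketch follows; the choice of `FixedLenFeature`/`VarLenFeature` types is an assumption based on the bytes / int-list / float-list storage above, not something this change prescribes:

```python
# Sketch (assumption, not from this commit): parsing the new context keys.
import tensorflow as tf

CONTEXT_FEATURES = {
    "text/context/content": tf.io.FixedLenFeature([], dtype=tf.string),
    "text/context/token_id": tf.io.VarLenFeature(dtype=tf.int64),
    "text/context/embedding": tf.io.VarLenFeature(dtype=tf.float32),
}

def parse_text_context(serialized):
  # Returns only the context; feature lists are ignored in this sketch.
  context, _ = tf.io.parse_single_sequence_example(
      serialized, context_features=CONTEXT_FEATURES)
  return context
```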
|
||||||
|
|
Loading…
Reference in New Issue
Block a user