From 64b21d758ed06872e51dfa39e1555258f94e0602 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 10 Nov 2023 10:00:56 -0800 Subject: [PATCH] Remove batch dimension from the output of tflite_with_tokenizer in text classifier. PiperOrigin-RevId: 581292824 --- .../python/text/text_classifier/model_with_tokenizer.py | 2 +- .../python/text/text_classifier/model_with_tokenizer_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/model_maker/python/text/text_classifier/model_with_tokenizer.py b/mediapipe/model_maker/python/text/text_classifier/model_with_tokenizer.py index 95328fb43..a96fe1b84 100644 --- a/mediapipe/model_maker/python/text/text_classifier/model_with_tokenizer.py +++ b/mediapipe/model_maker/python/text/text_classifier/model_with_tokenizer.py @@ -32,4 +32,4 @@ class ModelWithTokenizer(tf.keras.Model): x = self._tokenizer.process_fn(input_tensor) x = {k: tf.expand_dims(v, axis=0) for k, v in x.items()} x = self._model(x) - return x + return x[0] # TODO: Add back the batch dimension diff --git a/mediapipe/model_maker/python/text/text_classifier/model_with_tokenizer_test.py b/mediapipe/model_maker/python/text/text_classifier/model_with_tokenizer_test.py index f6c5d2477..1da09ab4e 100644 --- a/mediapipe/model_maker/python/text/text_classifier/model_with_tokenizer_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/model_with_tokenizer_test.py @@ -97,7 +97,7 @@ class BertTokenizerTest(tf.test.TestCase): self._tokenizer, self._model ) output = model(tf.constant(["Example input".encode("utf-8")])) - self.assertAllEqual(output.shape, (1, 2)) + self.assertAllEqual(output.shape, (2,)) self.assertEqual(tf.reduce_sum(output), 1)