Remove batch dimension from the output of tflite_with_tokenizer in text classifier.

PiperOrigin-RevId: 581292824
This commit is contained in:
MediaPipe Team 2023-11-10 10:00:56 -08:00 committed by Copybara-Service
parent d772bf8134
commit 64b21d758e
2 changed files with 2 additions and 2 deletions

View File

@ -32,4 +32,4 @@ class ModelWithTokenizer(tf.keras.Model):
x = self._tokenizer.process_fn(input_tensor)
x = {k: tf.expand_dims(v, axis=0) for k, v in x.items()}
x = self._model(x)
return x
return x[0] # TODO: Add back the batch dimension

View File

@ -97,7 +97,7 @@ class BertTokenizerTest(tf.test.TestCase):
self._tokenizer, self._model
)
output = model(tf.constant(["Example input".encode("utf-8")]))
self.assertAllEqual(output.shape, (1, 2))
self.assertAllEqual(output.shape, (2,))
self.assertEqual(tf.reduce_sum(output), 1)