Remove batch dimension from the output of tflite_with_tokenizer in text classifier.
PiperOrigin-RevId: 581292824
This commit is contained in:
parent
d772bf8134
commit
64b21d758e
|
@ -32,4 +32,4 @@ class ModelWithTokenizer(tf.keras.Model):
|
|||
x = self._tokenizer.process_fn(input_tensor)
|
||||
x = {k: tf.expand_dims(v, axis=0) for k, v in x.items()}
|
||||
x = self._model(x)
|
||||
return x
|
||||
return x[0] # TODO: Add back the batch dimension
|
||||
|
|
|
@ -97,7 +97,7 @@ class BertTokenizerTest(tf.test.TestCase):
|
|||
self._tokenizer, self._model
|
||||
)
|
||||
output = model(tf.constant(["Example input".encode("utf-8")]))
|
||||
self.assertAllEqual(output.shape, (1, 2))
|
||||
self.assertAllEqual(output.shape, (2,))
|
||||
self.assertEqual(tf.reduce_sum(output), 1)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user