Remove batch dimension from the output of tflite_with_tokenizer in text classifier.
PiperOrigin-RevId: 581292824
This commit is contained in:
parent
d772bf8134
commit
64b21d758e
|
@ -32,4 +32,4 @@ class ModelWithTokenizer(tf.keras.Model):
|
||||||
x = self._tokenizer.process_fn(input_tensor)
|
x = self._tokenizer.process_fn(input_tensor)
|
||||||
x = {k: tf.expand_dims(v, axis=0) for k, v in x.items()}
|
x = {k: tf.expand_dims(v, axis=0) for k, v in x.items()}
|
||||||
x = self._model(x)
|
x = self._model(x)
|
||||||
return x
|
return x[0] # TODO: Add back the batch dimension
|
||||||
|
|
|
@ -97,7 +97,7 @@ class BertTokenizerTest(tf.test.TestCase):
|
||||||
self._tokenizer, self._model
|
self._tokenizer, self._model
|
||||||
)
|
)
|
||||||
output = model(tf.constant(["Example input".encode("utf-8")]))
|
output = model(tf.constant(["Example input".encode("utf-8")]))
|
||||||
self.assertAllEqual(output.shape, (1, 2))
|
self.assertAllEqual(output.shape, (2,))
|
||||||
self.assertEqual(tf.reduce_sum(output), 1)
|
self.assertEqual(tf.reduce_sum(output), 1)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user