Remove batch dimension from the output of tflite_with_tokenizer in text classifier.
PiperOrigin-RevId: 581292824
This commit is contained in:
		
							parent
							
								
									d772bf8134
								
							
						
					
					
						commit
						64b21d758e
					
				| 
						 | 
				
			
			@ -32,4 +32,4 @@ class ModelWithTokenizer(tf.keras.Model):
 | 
			
		|||
    x = self._tokenizer.process_fn(input_tensor)
 | 
			
		||||
    x = {k: tf.expand_dims(v, axis=0) for k, v in x.items()}
 | 
			
		||||
    x = self._model(x)
 | 
			
		||||
    return x
 | 
			
		||||
    return x[0]  # TODO: Add back the batch dimension
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -97,7 +97,7 @@ class BertTokenizerTest(tf.test.TestCase):
 | 
			
		|||
        self._tokenizer, self._model
 | 
			
		||||
    )
 | 
			
		||||
    output = model(tf.constant(["Example input".encode("utf-8")]))
 | 
			
		||||
    self.assertAllEqual(output.shape, (1, 2))
 | 
			
		||||
    self.assertAllEqual(output.shape, (2,))
 | 
			
		||||
    self.assertEqual(tf.reduce_sum(output), 1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user