Remove batch dimension from the output of tflite_with_tokenizer in text classifier.
PiperOrigin-RevId: 581292824
This commit is contained in:
		
							parent
							
								
									d772bf8134
								
							
						
					
					
						commit
						64b21d758e
					
				| 
						 | 
					@ -32,4 +32,4 @@ class ModelWithTokenizer(tf.keras.Model):
 | 
				
			||||||
    x = self._tokenizer.process_fn(input_tensor)
 | 
					    x = self._tokenizer.process_fn(input_tensor)
 | 
				
			||||||
    x = {k: tf.expand_dims(v, axis=0) for k, v in x.items()}
 | 
					    x = {k: tf.expand_dims(v, axis=0) for k, v in x.items()}
 | 
				
			||||||
    x = self._model(x)
 | 
					    x = self._model(x)
 | 
				
			||||||
    return x
 | 
					    return x[0]  # TODO: Add back the batch dimension
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -97,7 +97,7 @@ class BertTokenizerTest(tf.test.TestCase):
 | 
				
			||||||
        self._tokenizer, self._model
 | 
					        self._tokenizer, self._model
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    output = model(tf.constant(["Example input".encode("utf-8")]))
 | 
					    output = model(tf.constant(["Example input".encode("utf-8")]))
 | 
				
			||||||
    self.assertAllEqual(output.shape, (1, 2))
 | 
					    self.assertAllEqual(output.shape, (2,))
 | 
				
			||||||
    self.assertEqual(tf.reduce_sum(output), 1)
 | 
					    self.assertEqual(tf.reduce_sum(output), 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user