Fix load_metadata_buffer
for empty metadata
PiperOrigin-RevId: 502870428
This commit is contained in:
parent
e484bd681e
commit
3688757d17
|
@ -860,6 +860,8 @@ def get_metadata_buffer(model_buf):
|
|||
if meta.Name().decode("utf-8") == MetadataPopulator.METADATA_FIELD_NAME:
|
||||
buffer_index = meta.Buffer()
|
||||
metadata = tflite_model.Buffers(buffer_index)
|
||||
if metadata.DataLength() == 0:
|
||||
continue
|
||||
return metadata.DataAsNumpy().tobytes()
|
||||
|
||||
return None
|
||||
|
|
|
@ -550,7 +550,7 @@ class MetadataPopulatorTest(MetadataTest):
|
|||
("The number of output tensors (1) should match the number of "
|
||||
"output tensor metadata (0)"), str(error.exception))
|
||||
|
||||
def testLoadMetadataAndAssociatedFilesShouldSucceeds(self):
|
||||
def testLoadMetadataAndAssociatedFilesShouldSucceed(self):
|
||||
# Create a src model with metadata and two associated files.
|
||||
src_model_buf = self._create_model_buf()
|
||||
populator_src = _metadata.MetadataPopulator.with_model_buffer(src_model_buf)
|
||||
|
@ -566,7 +566,7 @@ class MetadataPopulatorTest(MetadataTest):
|
|||
populator_src.get_model_buffer())
|
||||
populator_dst.populate()
|
||||
|
||||
# Tests if the metadata and associated files are populated correctly.
|
||||
# Test if the metadata and associated files are populated correctly.
|
||||
dst_model_file = self.create_tempfile().full_path
|
||||
with open(dst_model_file, "wb") as f:
|
||||
f.write(populator_dst.get_model_buffer())
|
||||
|
@ -575,6 +575,28 @@ class MetadataPopulatorTest(MetadataTest):
|
|||
recorded_files = populator_dst.get_recorded_associated_file_list()
|
||||
self.assertEqual(set(recorded_files), set(self.expected_recorded_files))
|
||||
|
||||
def testLoadMetadataAndAssociatedFilesShouldSucceedOnEmptyMetadata(self):
|
||||
# When the user hasn't specified the metadata, but only the associated
|
||||
# files, an empty metadata buffer is created. Previously, it caused an
|
||||
# exception when reading.
|
||||
|
||||
# Create a source model with two associated files but no metadata.
|
||||
src_model_buf = self._create_model_buf()
|
||||
populator_src = _metadata.MetadataPopulator.with_model_buffer(src_model_buf)
|
||||
populator_src.load_associated_files([self._file1, self._file2])
|
||||
populator_src.populate()
|
||||
|
||||
# Create a model to be populated with the files from `src_model_buf`.
|
||||
dst_model_buf = self._create_model_buf()
|
||||
populator_dst = _metadata.MetadataPopulator.with_model_buffer(dst_model_buf)
|
||||
populator_dst.load_metadata_and_associated_files(
|
||||
populator_src.get_model_buffer())
|
||||
populator_dst.populate()
|
||||
|
||||
# Test if the metadata and associated files are populated correctly.
|
||||
packed_files = populator_dst.get_packed_associated_file_list()
|
||||
self.assertEqual(set(packed_files), set(self.expected_recorded_files))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{
|
||||
"testcase_name": "InputTensorWithBert",
|
||||
|
|
Loading…
Reference in New Issue
Block a user