Fix load_metadata_buffer
for empty metadata
PiperOrigin-RevId: 502870428
This commit is contained in:
parent
e484bd681e
commit
3688757d17
|
@ -860,6 +860,8 @@ def get_metadata_buffer(model_buf):
|
||||||
if meta.Name().decode("utf-8") == MetadataPopulator.METADATA_FIELD_NAME:
|
if meta.Name().decode("utf-8") == MetadataPopulator.METADATA_FIELD_NAME:
|
||||||
buffer_index = meta.Buffer()
|
buffer_index = meta.Buffer()
|
||||||
metadata = tflite_model.Buffers(buffer_index)
|
metadata = tflite_model.Buffers(buffer_index)
|
||||||
|
if metadata.DataLength() == 0:
|
||||||
|
continue
|
||||||
return metadata.DataAsNumpy().tobytes()
|
return metadata.DataAsNumpy().tobytes()
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
|
@ -550,7 +550,7 @@ class MetadataPopulatorTest(MetadataTest):
|
||||||
("The number of output tensors (1) should match the number of "
|
("The number of output tensors (1) should match the number of "
|
||||||
"output tensor metadata (0)"), str(error.exception))
|
"output tensor metadata (0)"), str(error.exception))
|
||||||
|
|
||||||
def testLoadMetadataAndAssociatedFilesShouldSucceeds(self):
|
def testLoadMetadataAndAssociatedFilesShouldSucceed(self):
|
||||||
# Create a src model with metadata and two associated files.
|
# Create a src model with metadata and two associated files.
|
||||||
src_model_buf = self._create_model_buf()
|
src_model_buf = self._create_model_buf()
|
||||||
populator_src = _metadata.MetadataPopulator.with_model_buffer(src_model_buf)
|
populator_src = _metadata.MetadataPopulator.with_model_buffer(src_model_buf)
|
||||||
|
@ -566,7 +566,7 @@ class MetadataPopulatorTest(MetadataTest):
|
||||||
populator_src.get_model_buffer())
|
populator_src.get_model_buffer())
|
||||||
populator_dst.populate()
|
populator_dst.populate()
|
||||||
|
|
||||||
# Tests if the metadata and associated files are populated correctly.
|
# Test if the metadata and associated files are populated correctly.
|
||||||
dst_model_file = self.create_tempfile().full_path
|
dst_model_file = self.create_tempfile().full_path
|
||||||
with open(dst_model_file, "wb") as f:
|
with open(dst_model_file, "wb") as f:
|
||||||
f.write(populator_dst.get_model_buffer())
|
f.write(populator_dst.get_model_buffer())
|
||||||
|
@ -575,6 +575,28 @@ class MetadataPopulatorTest(MetadataTest):
|
||||||
recorded_files = populator_dst.get_recorded_associated_file_list()
|
recorded_files = populator_dst.get_recorded_associated_file_list()
|
||||||
self.assertEqual(set(recorded_files), set(self.expected_recorded_files))
|
self.assertEqual(set(recorded_files), set(self.expected_recorded_files))
|
||||||
|
|
||||||
|
def testLoadMetadataAndAssociatedFilesShouldSucceedOnEmptyMetadata(self):
|
||||||
|
# When the user hasn't specified the metadata, but only the associated
|
||||||
|
# files, an empty metadata buffer is created. Previously, it caused an
|
||||||
|
# exception when reading.
|
||||||
|
|
||||||
|
# Create a source model with two associated files but no metadata.
|
||||||
|
src_model_buf = self._create_model_buf()
|
||||||
|
populator_src = _metadata.MetadataPopulator.with_model_buffer(src_model_buf)
|
||||||
|
populator_src.load_associated_files([self._file1, self._file2])
|
||||||
|
populator_src.populate()
|
||||||
|
|
||||||
|
# Create a model to be populated with the files from `src_model_buf`.
|
||||||
|
dst_model_buf = self._create_model_buf()
|
||||||
|
populator_dst = _metadata.MetadataPopulator.with_model_buffer(dst_model_buf)
|
||||||
|
populator_dst.load_metadata_and_associated_files(
|
||||||
|
populator_src.get_model_buffer())
|
||||||
|
populator_dst.populate()
|
||||||
|
|
||||||
|
# Test if the metadata and associated files are populated correctly.
|
||||||
|
packed_files = populator_dst.get_packed_associated_file_list()
|
||||||
|
self.assertEqual(set(packed_files), set(self.expected_recorded_files))
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
{
|
{
|
||||||
"testcase_name": "InputTensorWithBert",
|
"testcase_name": "InputTensorWithBert",
|
||||||
|
|
Loading…
Reference in New Issue
Block a user