Fix load_metadata_buffer for empty metadata

PiperOrigin-RevId: 502870428
This commit is contained in:
MediaPipe Team 2023-01-18 07:26:38 -08:00 committed by Copybara-Service
parent e484bd681e
commit 3688757d17
2 changed files with 26 additions and 2 deletions

View File

@ -860,6 +860,8 @@ def get_metadata_buffer(model_buf):
if meta.Name().decode("utf-8") == MetadataPopulator.METADATA_FIELD_NAME:
buffer_index = meta.Buffer()
metadata = tflite_model.Buffers(buffer_index)
if metadata.DataLength() == 0:
continue
return metadata.DataAsNumpy().tobytes()
return None

View File

@ -550,7 +550,7 @@ class MetadataPopulatorTest(MetadataTest):
("The number of output tensors (1) should match the number of "
"output tensor metadata (0)"), str(error.exception))
def testLoadMetadataAndAssociatedFilesShouldSucceeds(self):
def testLoadMetadataAndAssociatedFilesShouldSucceed(self):
# Create a src model with metadata and two associated files.
src_model_buf = self._create_model_buf()
populator_src = _metadata.MetadataPopulator.with_model_buffer(src_model_buf)
@ -566,7 +566,7 @@ class MetadataPopulatorTest(MetadataTest):
populator_src.get_model_buffer())
populator_dst.populate()
# Tests if the metadata and associated files are populated correctly.
# Test if the metadata and associated files are populated correctly.
dst_model_file = self.create_tempfile().full_path
with open(dst_model_file, "wb") as f:
f.write(populator_dst.get_model_buffer())
@ -575,6 +575,28 @@ class MetadataPopulatorTest(MetadataTest):
recorded_files = populator_dst.get_recorded_associated_file_list()
self.assertEqual(set(recorded_files), set(self.expected_recorded_files))
def testLoadMetadataAndAssociatedFilesShouldSucceedOnEmptyMetadata(self):
# When the user hasn't specified the metadata, but only the associated
# files, an empty metadata buffer is created. Previously, it caused an
# exception when reading.
# Create a source model with two associated files but no metadata.
src_model_buf = self._create_model_buf()
populator_src = _metadata.MetadataPopulator.with_model_buffer(src_model_buf)
populator_src.load_associated_files([self._file1, self._file2])
populator_src.populate()
# Create a model to be populated with the files from `src_model_buf`.
dst_model_buf = self._create_model_buf()
populator_dst = _metadata.MetadataPopulator.with_model_buffer(dst_model_buf)
populator_dst.load_metadata_and_associated_files(
populator_src.get_model_buffer())
populator_dst.populate()
# Test if the metadata and associated files are populated correctly.
packed_files = populator_dst.get_packed_associated_file_list()
self.assertEqual(set(packed_files), set(self.expected_recorded_files))
@parameterized.named_parameters(
{
"testcase_name": "InputTensorWithBert",