| 331 | |
| 332 | |
def to_h5py(value, h5: "h5py.File", key: str = None, compression: str = 'gzip'):
    """Recursively write ``value`` into an HDF5 file or group.

    Leaves (torch tensors, numpy arrays, numeric scalars) become datasets.
    Containers (lists, tuples, dicts) become groups whose members are written
    recursively; sequence members are keyed by their index (``'0'``, ``'1'``, ...).

    Args:
        value: tensor, ndarray, numeric scalar, list/tuple or dict (possibly
            nested) to write.
        h5: open ``h5py.File`` / ``h5py.Group`` (anything exposing
            ``create_dataset`` and ``create_group``) to write into.
        key: name of the dataset/group created under ``h5`` (stringified).
            For a container, ``None`` writes its members directly into ``h5``.
            Leaves should always be given a key; otherwise they are stored
            under the name ``'None'``.
        compression: HDF5 compression filter applied to every non-scalar
            dataset, including datasets nested inside containers.

    Raises:
        NotImplementedError: if ``value`` or any nested member has an
            unsupported type.
    """
    if isinstance(value, torch.Tensor):
        value = value.detach().cpu().numpy()
    if isinstance(value, (bool, int, float, complex, np.number, np.bool_)):
        # Plain/numpy scalars are stored as 0-d (scalar) datasets.
        value = np.asarray(value)
    if isinstance(value, np.ndarray):
        if value.ndim == 0:
            # HDF5 cannot apply filters (compression/chunking) to scalar
            # datasets; passing `compression` here would raise.
            h5.create_dataset(str(key), data=value)
        else:
            h5.create_dataset(str(key), data=value, compression=compression)
    elif isinstance(value, (list, tuple)):
        if key is not None:
            h5 = h5.create_group(str(key))
        for index, item in enumerate(value):
            # Propagate the caller's compression choice to nested datasets.
            to_h5py(item, h5, index, compression=compression)
    elif isinstance(value, dict):
        if key is not None:
            h5 = h5.create_group(str(key))
        for name, item in value.items():
            to_h5py(item, h5, name, compression=compression)
    else:
        raise NotImplementedError(f'unsupported type to write to h5: {type(value)}')
| 348 | |
| 349 | |
| 350 | def export_h5(batch: dotdict, filename): |