(tensor,
save_file=None,
fps=30,
suffix='.mp4',
nrow=8,
normalize=True,
value_range=(-1, 1),
retry=5)
| 55 | raise argparse.ArgumentTypeError('Boolean value expected (True/False)') |
| 56 | |
def cache_video(tensor,
                save_file=None,
                fps=30,
                suffix='.mp4',
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    """Encode a batch of videos as one grid video file and return its path.

    Every time step is tiled into an image grid with
    ``torchvision.utils.make_grid``. The resulting frames are written with
    ``imageio`` using the ``libx264`` codec.

    Args:
        tensor: Video tensor. Axis 2 is unbound frame by frame and each
            slice is passed to ``make_grid``, so this presumably expects
            ``(B, C, T, H, W)``. NOTE(review): confirm against callers.
        save_file: Destination path. When ``None``, a random file name
            ending in ``suffix`` is generated under ``/tmp``.
        fps: Frame rate passed to the video writer.
        suffix: Suffix for the generated cache file name. Ignored when
            ``save_file`` is given.
        nrow: Number of videos per grid row (forwarded to ``make_grid``).
        normalize: Forwarded to ``make_grid``. When ``True``, frames are
            rescaled from ``value_range`` to ``[0, 1]`` before conversion
            to uint8.
        value_range: ``(low, high)`` range the input is clamped to.
        retry: Maximum number of attempts.

    Returns:
        The path that was written, or ``None`` if every attempt failed.
        In that case the last error is printed.
    """
    # cache file
    cache_file = osp.join('/tmp', rand_name(
        suffix=suffix)) if save_file is None else save_file

    # save to cache
    error = None
    for _ in range(retry):
        try:
            # Preprocess into a local variable so that each retry starts
            # again from the caller's input, not from frames that a failed
            # attempt already converted to uint8 (T, H, W, C).
            frames = tensor.clamp(min(value_range), max(value_range))
            frames = torch.stack([
                torchvision.utils.make_grid(
                    u, nrow=nrow, normalize=normalize, value_range=value_range)
                for u in frames.unbind(2)
            ],
                                 dim=1).permute(1, 2, 3, 0)
            # Stacking (C, H, W) grids on dim=1 then permuting gives
            # (T, H, W, C) frames, which are converted to uint8 on the CPU.
            frames = (frames * 255).type(torch.uint8).cpu()

            # write video; the context manager closes the writer even if
            # appending a frame raises, so failed attempts do not leak it.
            with imageio.get_writer(
                    cache_file, fps=fps, codec='libx264',
                    quality=8) as writer:
                for frame in frames.numpy():
                    writer.append_data(frame)
            return cache_file
        except Exception as e:
            # Best-effort: remember the error and try again.
            error = e
    print(f'cache_video failed, error: {error}', flush=True)
    return None
| 96 | |
| 97 | |
| 98 | def cache_image(tensor, |
nothing calls this directly
no test coverage detected