(t: torch.Tensor)
| 5 | |
| 6 | |
| 7 | def tensor2bytes(t: torch.Tensor): |
| 8 | # t = t.cpu().numpy().tobytes() |
| 9 | # return t |
| 10 | buf = BytesIO() |
| 11 | t = t.detach().cpu() |
| 12 | # 这个地方进行新的empty并复制是因为,torch的tensor save的机制存在问题 |
| 13 | # 如果 t 是从一个大 tensor 上切片复制下来的的tensor, 在save的时候,其 |
| 14 | # 会保存大tensor的所有数据,所以会导致存储开销较大,需要申请一个新的tensor |
| 15 | # 并进行复制,来打断这种联系。 |
| 16 | dest = torch.empty_like(t) |
| 17 | dest.copy_(t) |
| 18 | torch.save(dest, buf, _use_new_zipfile_serialization=False, pickle_protocol=4) |
| 19 | buf.seek(0) |
| 20 | return buf.read() |
| 21 | |
| 22 | |
| 23 | def bytes2tensor(b): |
no test coverage detected