(source, dest_size)
| 1 | import torch |
| 2 | |
def tensor_to_size(source, dest_size):
    """Pad or trim ``source`` along dim 0 so it has exactly ``dest_size`` entries.

    If ``source`` is shorter than requested, its last slice along dim 0 is
    repeated to fill the gap. If it is longer, the trailing entries are dropped.

    Args:
        source: Tensor whose first dimension is resized.
        dest_size: Target length along dim 0. Either an int, or a tensor whose
            ``shape[0]`` is used as the target length.

    Returns:
        A tensor with ``shape[0] == dest_size``; all other dimensions of
        ``source`` are unchanged. ``source`` itself is returned unchanged when
        it already has the requested length.

    Raises:
        ValueError: if ``dest_size`` is negative, or if ``source`` is empty
            along dim 0 but must be padded (there is no last element to repeat).
    """
    if isinstance(dest_size, torch.Tensor):
        dest_size = dest_size.shape[0]
    source_size = source.shape[0]

    if dest_size < 0:
        # A negative value would otherwise act as a Python negative slice
        # (source[:-k]) and silently drop rows from the end.
        raise ValueError(f"dest_size must be non-negative, got {dest_size}")

    if source_size < dest_size:
        if source_size == 0:
            # source[-1:] would be empty, so torch.cat would quietly return an
            # empty tensor instead of one with dest_size entries.
            raise ValueError(
                "cannot pad an empty tensor by repeating its last element"
            )
        # Repeat only along dim 0; every other dim keeps a repeat factor of 1.
        shape = [dest_size - source_size] + [1] * (source.dim() - 1)
        source = torch.cat((source, source[-1:].repeat(shape)), dim=0)
    elif source_size > dest_size:
        source = source[:dest_size]

    return source
| 15 | |
| 16 | def tensor_to_image(tensor): |
| 17 | image = tensor.mul(255).clamp(0, 255).byte().cpu() |
no outgoing calls
no test coverage detected