r"""Convert tensor to image. Args: image_tensor (torch.tensor or list of torch.tensor): If tensor then (NxCxHxW) or (NxTxCxHxW) or (CxHxW). imtype (np.dtype): Type of output image. normalize (bool): Is the input image normalized or not? three_chan
(image_tensor, imtype=np.uint8, normalize=True,
three_channel_output=True)
| 70 | |
| 71 | |
def tensor2im(image_tensor, imtype=np.uint8, normalize=True,
              three_channel_output=True):
    r"""Convert tensor to image.

    Args:
        image_tensor (torch.tensor or list of torch.tensor): If tensor then
            (NxCxHxW) or (NxTxCxHxW) or (CxHxW). List elements may also be
            None, in which case None is returned in their place.
        imtype (np.dtype): Type of output image.
        normalize (bool): Is the input image normalized or not? If True,
            values are mapped with ``(x + 1) / 2 * 255`` (i.e. [-1, 1] to
            [0, 255]); otherwise they are multiplied by 255. The result is
            clipped to [0, 255] before the cast to ``imtype``.
        three_channel_output (bool): Should single channel images be made 3
            channel in output? Applied to every element of list / batched
            inputs as well.

    Returns:
        None if ``image_tensor`` is None. For a CxHxW tensor, a numpy.ndarray
        of shape HxWxC' and dtype ``imtype``, where C' is 3 for a
        single-channel input with ``three_channel_output``, 3 when C > 3
        (extra channels are dropped), and C otherwise. For a list or a
        4-D / 5-D tensor, a list of such results (nested one level deeper
        for 5-D input).

    Raises:
        ValueError: If a tensor has a dimensionality other than 3, 4 or 5.
    """
    if image_tensor is None:
        return None

    # Forward every option so list / batched inputs honour the caller's
    # settings for each element instead of falling back to the defaults.
    if isinstance(image_tensor, list):
        return [tensor2im(x, imtype=imtype, normalize=normalize,
                          three_channel_output=three_channel_output)
                for x in image_tensor]
    if image_tensor.dim() in (4, 5):
        return [tensor2im(image_tensor[idx], imtype=imtype,
                          normalize=normalize,
                          three_channel_output=three_channel_output)
                for idx in range(image_tensor.size(0))]
    if image_tensor.dim() != 3:
        raise ValueError(
            'tensor2im expects a 3-, 4- or 5-D tensor, got a %d-D tensor'
            % image_tensor.dim())

    # Copy to CPU as float32 and reorder CxHxW -> HxWxC.
    image_numpy = np.transpose(image_tensor.cpu().float().numpy(), (1, 2, 0))
    if normalize:
        image_numpy = (image_numpy + 1) / 2.0 * 255.0
    else:
        image_numpy = image_numpy * 255.0
    # Clip before the dtype cast so out-of-range values saturate rather
    # than wrap around (e.g. for uint8).
    image_numpy = np.clip(image_numpy, 0, 255)

    if image_numpy.shape[2] == 1 and three_channel_output:
        image_numpy = np.repeat(image_numpy, 3, axis=2)
    elif image_numpy.shape[2] > 3:
        # Keep only the first three channels.
        image_numpy = image_numpy[:, :, :3]
    return image_numpy.astype(imtype)
| 108 | |
| 109 | |
| 110 | def tensor2label(segmap, n_label=None, imtype=np.uint8, |
no outgoing calls