| 70 | |
| 71 | |
def make_resizer(library, filter, output_size):
    """Build a function that resizes an ``H x W x C`` image array.

    Args:
        library: Backend to use, either ``"PIL"`` or ``"PyTorch"``.
        filter: Interpolation name. For ``"PIL"`` it is looked up in
            ``dict_name_to_filter["PIL"]``; for ``"PyTorch"`` it is passed
            as ``mode`` to ``F.interpolate``.
        output_size: Pair ``(s1, s2)`` giving the output height and width.
            The returned array has shape ``(s1, s2, C)`` for both backends.

    Returns:
        A callable ``func(x)`` taking a NumPy array indexed as
        ``x[row, col, channel]`` and returning the resized NumPy array.
        The PIL backend resizes the first three channels only and returns
        ``float32``; the PyTorch backend clips its result to ``[0, 255]``.

    Raises:
        NotImplementedError: If ``library`` is not a supported backend.
    """
    if library == "PIL":
        out_h, out_w = output_size

        def resize_single_channel(x_np):
            # Mode 'F' keeps the channel as 32-bit float, so no quantization
            # to uint8 happens before resampling.
            img = Image.fromarray(x_np.astype(np.float32), mode='F')
            # PIL sizes are (width, height) while NumPy arrays are
            # (height, width). Swap the pair so the array coming back really
            # is (out_h, out_w); passing output_size unchanged and then
            # reshaping scrambles pixels for non-square sizes.
            img = img.resize(
                (out_w, out_h),
                resample=dict_name_to_filter[library][filter],
            )
            return np.asarray(img).reshape(out_h, out_w, 1)

        def func(x):
            # NOTE(review): exactly three channels are processed; extra
            # channels are dropped and fewer than three raise IndexError.
            # NOTE(review): unlike the PyTorch path, the result is not
            # clipped to [0, 255] - confirm whether that is intended.
            channels = [resize_single_channel(x[:, :, idx]) for idx in range(3)]
            return np.concatenate(channels, axis=2).astype(np.float32)
    elif library == "PyTorch":
        import warnings
        # Ignore the numpy warnings.
        # NOTE(review): this silences *all* warnings process-wide from the
        # moment the resizer is built; kept for backward compatibility.
        warnings.filterwarnings("ignore")

        def func(x):
            # HWC -> CHW, plus a leading batch axis as F.interpolate expects.
            x = torch.Tensor(x.transpose((2, 0, 1)))[None, ...]
            x = F.interpolate(x, size=output_size, mode=filter, align_corners=False)
            # Drop the batch axis, go back to HWC and clamp to the pixel range.
            x = x[0, ...].cpu().data.numpy().transpose((1, 2, 0)).clip(0, 255)
            return x
    else:
        raise NotImplementedError('library [%s] is not include' % library)
    return func