(self, base_file)
| 111 | return output |
| 112 | |
def load_weights(self, base_file):
    """Load a saved state dict from *base_file* into this module.

    Only files whose extension is ``.pkl`` or ``.pth`` are loaded. For
    any other extension a message is printed and the module is left
    unchanged (no exception is raised, matching the original behavior).

    Args:
        base_file: Path to a file written by ``torch.save`` that contains
            a state dict compatible with this module.
    """
    _, ext = os.path.splitext(base_file)
    # The original `ext == '.pkl' or '.pth'` was always truthy (a non-empty
    # string literal), so the else branch could never run. Test membership.
    if ext in ('.pkl', '.pth'):
        print('Loading weights into state dict...')
        # This map_location returns each storage unchanged, so tensors stay
        # where torch.load deserialized them (CPU), whatever device they
        # were saved from.
        self.load_state_dict(torch.load(base_file,
                             map_location=lambda storage, loc: storage))
        print('Finished!')
    else:
        print('Sorry only .pth and .pkl files supported.')
| 122 | |
| 123 | |
| 124 | # This function is derived from torchvision VGG make_layers() |