INPUTS: model – ['net-lin'] for a linearly calibrated network; ['net'] for an off-the-shelf network; ['L2'] for L2 distance in Lab colorspace; ['SSIM'] for SSIM in RGB colorspace. net – one of ['squeeze', 'alex', 'vgg'].
(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None,
use_gpu=True, printNet=False, spatial=False,
is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0])
| 26 | return self.model_name |
| 27 | |
| 28 | def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None, |
| 29 | use_gpu=True, printNet=False, spatial=False, |
| 30 | is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]): |
| 31 | ''' |
| 32 | INPUTS |
| 33 | model - ['net-lin'] for linearly calibrated network |
| 34 | ['net'] for off-the-shelf network |
| 35 | ['L2'] for L2 distance in Lab colorspace |
| 36 | ['SSIM'] for ssim in RGB colorspace |
| 37 | net - ['squeeze','alex','vgg'] |
| 38 | model_path - if None, will look in weights/[NET_NAME].pth |
| 39 | colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM |
| 40 | use_gpu - bool - whether or not to use a GPU |
| 41 | printNet - bool - whether or not to print network architecture out |
| 42 | spatial - bool - whether to output an array containing varying distances across spatial dimensions |
| 43 | is_train - bool - [True] for training mode |
| 44 | lr - float - initial learning rate |
| 45 | beta1 - float - initial momentum term for adam |
| 46 | version - 0.1 for latest, 0.0 was original (with a bug) |
| 47 | gpu_ids - int array - [0] by default, gpus to use |
| 48 | ''' |
| 49 | BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids) |
| 50 | |
| 51 | self.model = model |
| 52 | self.net = net |
| 53 | self.is_train = is_train |
| 54 | self.spatial = spatial |
| 55 | self.gpu_ids = gpu_ids |
| 56 | self.model_name = '%s [%s]'%(model,net) |
| 57 | |
| 58 | if(self.model == 'net-lin'): # pretrained net + linear layer |
| 59 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net, |
| 60 | use_dropout=True, spatial=spatial, version=version, lpips=True) |
| 61 | kw = {} |
| 62 | if not use_gpu: |
| 63 | kw['map_location'] = 'cpu' |
| 64 | if(model_path is None): |
| 65 | import inspect |
| 66 | model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net))) |
| 67 | |
| 68 | if(not is_train): |
| 69 | print('Loading model from: %s'%model_path) |
| 70 | self.net.load_state_dict(torch.load(model_path, **kw), strict=False) |
| 71 | |
| 72 | elif(self.model=='net'): # pretrained network |
| 73 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False) |
| 74 | elif(self.model in ['L2','l2']): |
| 75 | self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing |
| 76 | self.model_name = 'L2' |
| 77 | elif(self.model in ['DSSIM','dssim','SSIM','ssim']): |
| 78 | self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace) |
| 79 | self.model_name = 'SSIM' |
| 80 | else: |
| 81 | raise ValueError("Model [%s] not recognized." % self.model) |
| 82 | |
| 83 | self.parameters = list(self.net.parameters()) |
| 84 | |
| 85 | if self.is_train: # training mode |
no outgoing calls
no test coverage detected