Run Model
| 113 | |
| 114 | |
class Running(object):
    """Run Model

    Bundles the parsed command-line arguments and the target device id, and
    exposes the baseline-evaluation and training entry points built on them.
    """

    def __init__(self, args, device_id):
        """
        :param args: parser.parse_args()
        :param device_id: 0 or -1
        """
        self.args = args
        self.device_id = device_id
        # Flags copied from a checkpoint's saved options onto ``self.args``
        # when training resumes (see ``_apply_checkpoint_flags``).
        self.model_flags = ['hidden_size', 'ff_size', 'heads', 'inter_layers', 'encoder', 'ff_actv', 'use_interval',
                            'rnn_size']

        # Set up logging before the first message is emitted so the device
        # lines below are not logged ahead of the logger configuration.
        # NOTE(review): presumably init_logger attaches the handlers for
        # args.log_file -- confirm against its definition.
        init_logger(args.log_file)

        self.device = "cpu" if self.args.visible_gpus == '-1' else "cuda"
        logger.info('Device ID %d', self.device_id)
        logger.info('Device %s', self.device)
        torch.manual_seed(self.args.seed)
        random.seed(self.args.seed)

        # A negative id means "no GPU selected"; only bind a CUDA device otherwise.
        if self.device_id >= 0:
            torch.cuda.set_device(self.device_id)

    def _apply_checkpoint_flags(self, saved_opt):
        """
        Copy every attribute of ``saved_opt`` whose name is listed in
        ``self.model_flags`` onto ``self.args``; all other attributes are ignored.

        :param saved_opt: object whose ``vars()`` holds the options stored in a checkpoint
        """
        for name, value in vars(saved_opt).items():
            if name in self.model_flags:
                setattr(self.args, name, value)

    def baseline(self, cal_lead=False, cal_oracle=False):
        """
        Run ``trainer.test`` on the un-shuffled test split without a model.

        :param cal_lead: call ``trainer.test(..., cal_lead=True)``; wins when both flags are set
        :param cal_oracle: call ``trainer.test(..., cal_oracle=True)``
        """
        if not (cal_lead or cal_oracle):
            # Nothing would be evaluated; skip loading the test set and say so
            # instead of silently doing the work and discarding it.
            logger.warning('baseline() called without cal_lead or cal_oracle; nothing to evaluate')
            return

        test_iter = data_loader.DataLoader(self.args, data_loader.load_dataset(self.args, 'test', shuffle=False),
                                           self.args.batch_size, self.device, shuffle=False, is_test=True)

        # No model and no optimizer are passed to the trainer here.
        trainer = build_trainer(self.args, self.device_id, None, None)

        if cal_lead:
            trainer.test(test_iter, 0, cal_lead=True)
        else:
            trainer.test(test_iter, 0, cal_oracle=True)

    def train_iter(self):
        """
        Build a fresh shuffled loader over the 'train' split.

        :return: a new ``data_loader.DataLoader`` each time it is called
        """
        return data_loader.DataLoader(self.args, data_loader.load_dataset(self.args, 'train', shuffle=True),
                                      self.args.batch_size, self.device, shuffle=True, is_test=False)

    def train(self):
        """
        Build the summarizer and optimizer, optionally resuming from
        ``self.args.train_from``, then train for ``self.args.train_steps`` steps.
        """
        checkpoint = None
        if self.args.train_from:
            logger.info('Loading checkpoint from %s', self.args.train_from)
            # NOTE: torch.load unpickles the file -- only resume from trusted checkpoints.
            # map_location keeps every tensor on CPU storage while loading.
            checkpoint = torch.load(self.args.train_from, map_location=lambda storage, loc: storage)
            # Restore the saved flags BEFORE constructing the model, so the
            # model is built from the checkpoint's settings rather than from
            # whatever was passed on the command line for this run.
            self._apply_checkpoint_flags(checkpoint['opt'])

        model = model_builder.Summarizer(self.args, self.device, load_pretrained_bert=True)
        if checkpoint is not None:
            model.load_cp(checkpoint)
        # ``checkpoint`` is None for a fresh run, exactly as before.
        optimizer = model_builder.build_optim(self.args, model, checkpoint)

        logger.info(model)
        trainer = build_trainer(self.args, self.device_id, model, optimizer)
        # The bound method (not its result) is handed over so the trainer can
        # request a new training loader whenever it needs one.
        trainer.train(self.train_iter, self.args.train_steps)
| 172 |
no outgoing calls
no test coverage detected