:param args: parser.parse_args() :param device_id: 0 or -1
(self, args, device_id)
| 116 | """Run Model""" |
| 117 | |
def __init__(self, args, device_id):
    """
    Store the run configuration, set up logging, pick the device and seed
    the random number generators.

    :param args: parser.parse_args(); ``visible_gpus``, ``seed`` and
        ``log_file`` are read from it here
    :param device_id: 0 or -1; a non-negative value is handed to
        ``torch.cuda.set_device``
    """
    self.args = args
    self.device_id = device_id
    self.model_flags = ['hidden_size', 'ff_size', 'heads', 'inter_layers', 'encoder', 'ff_actv', 'use_interval',
                        'rnn_size']

    # Configure logging before the first message is emitted. Previously this
    # ran at the very end, after the device messages below had been logged.
    # NOTE(review): init_logger is defined elsewhere; presumably it installs
    # the handlers/level for `logger` -- confirm it is safe to call first.
    init_logger(args.log_file)

    # '-1' in visible_gpus selects the CPU; any other value selects CUDA.
    self.device = "cpu" if self.args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d', self.device_id)
    logger.info('Device %s', self.device)

    # Seed both torch and the stdlib RNG from the same configured seed.
    torch.manual_seed(self.args.seed)
    random.seed(self.args.seed)

    # Only bind a CUDA device when a valid (non-negative) id was given.
    if self.device_id >= 0:
        torch.cuda.set_device(self.device_id)
| 138 | |
| 139 | def baseline(self, cal_lead=False, cal_oracle=False): |
| 140 | test_iter = data_loader.DataLoader(self.args, data_loader.load_dataset(self.args, 'test', shuffle=False), |
nothing calls this directly
no test coverage detected