(args)
| 107 | |
| 108 | |
| 109 | def main(args): |
| 110 | device = args.device if args.device >= 0 else "cpu" |
| 111 | seeds = args.seeds |
| 112 | dataset_name = args.dataset |
| 113 | max_epoch = args.max_epoch |
| 114 | max_epoch_f = args.max_epoch_f |
| 115 | num_hidden = args.num_hidden |
| 116 | num_layers = args.num_layers |
| 117 | encoder_type = args.encoder |
| 118 | decoder_type = args.decoder |
| 119 | replace_rate = args.replace_rate |
| 120 | |
| 121 | optim_type = args.optimizer |
| 122 | loss_fn = args.loss_fn |
| 123 | |
| 124 | lr = args.lr |
| 125 | weight_decay = args.weight_decay |
| 126 | lr_f = args.lr_f |
| 127 | weight_decay_f = args.weight_decay_f |
| 128 | linear_prob = args.linear_prob |
| 129 | load_model = args.load_model |
| 130 | save_model = args.save_model |
| 131 | logs = args.logging |
| 132 | use_scheduler = args.scheduler |
| 133 | pooling = args.pooling |
| 134 | deg4feat = args.deg4feat |
| 135 | batch_size = args.batch_size |
| 136 | |
| 137 | graphs, (num_features, num_classes) = load_graph_classification_dataset(dataset_name, deg4feat=deg4feat) |
| 138 | args.num_features = num_features |
| 139 | |
| 140 | train_loader = DataLoader(graphs, collate_fn=collate_fn, batch_size=batch_size, pin_memory=True) |
| 141 | eval_loader = DataLoader(graphs, collate_fn=collate_fn, batch_size=batch_size, shuffle=False) |
| 142 | |
| 143 | if pooling == "mean": |
| 144 | pooler = batch_mean_pooling |
| 145 | elif pooling == "max": |
| 146 | pooler = batch_max_pooling |
| 147 | elif pooling == "sum": |
| 148 | pooler = batch_sum_pooling |
| 149 | else: |
| 150 | raise NotImplementedError |
| 151 | |
| 152 | acc_list = [] |
| 153 | for i, seed in enumerate(seeds): |
| 154 | print(f"####### Run {i} for seed {seed}") |
| 155 | set_random_seed(seed) |
| 156 | |
| 157 | if logs: |
| 158 | logger = TBLogger(name=f"{dataset_name}_loss_{loss_fn}_rpr_{replace_rate}_nh_{num_hidden}_nl_{num_layers}_lr_{lr}_mp_{max_epoch}_mpf_{max_epoch_f}_wd_{weight_decay}_wdf_{weight_decay_f}_{encoder_type}_{decoder_type}") |
| 159 | else: |
| 160 | logger = None |
| 161 | |
| 162 | model = build_model(args) |
| 163 | model.to(device) |
| 164 | optimizer = create_optimizer(optim_type, model, lr, weight_decay) |
| 165 | |
| 166 | if use_scheduler: |
no test coverage detected