perform graph classification task
(prog_args)
| 157 | |
| 158 | |
| 159 | def graph_classify_task(prog_args): |
| 160 | """ |
| 161 | perform graph classification task |
| 162 | """ |
| 163 | |
| 164 | dataset = tu.LegacyTUDataset(name=prog_args.dataset) |
| 165 | train_size = int(prog_args.train_ratio * len(dataset)) |
| 166 | test_size = int(prog_args.test_ratio * len(dataset)) |
| 167 | val_size = int(len(dataset) - train_size - test_size) |
| 168 | |
| 169 | dataset_train, dataset_val, dataset_test = torch.utils.data.random_split( |
| 170 | dataset, (train_size, val_size, test_size) |
| 171 | ) |
| 172 | train_dataloader = prepare_data( |
| 173 | dataset_train, prog_args, train=True, pre_process=pre_process |
| 174 | ) |
| 175 | val_dataloader = prepare_data( |
| 176 | dataset_val, prog_args, train=False, pre_process=pre_process |
| 177 | ) |
| 178 | test_dataloader = prepare_data( |
| 179 | dataset_test, prog_args, train=False, pre_process=pre_process |
| 180 | ) |
| 181 | input_dim, label_dim, max_num_node = dataset.statistics() |
| 182 | print("++++++++++STATISTICS ABOUT THE DATASET") |
| 183 | print("dataset feature dimension is", input_dim) |
| 184 | print("dataset label dimension is", label_dim) |
| 185 | print("the max num node is", max_num_node) |
| 186 | print("number of graphs is", len(dataset)) |
| 187 | # assert len(dataset) % prog_args.batch_size == 0, "training set not divisible by batch size" |
| 188 | |
| 189 | hidden_dim = 64 # used to be 64 |
| 190 | embedding_dim = 64 |
| 191 | |
| 192 | # calculate assignment dimension: pool_ratio * largest graph's maximum |
| 193 | # number of nodes in the dataset |
| 194 | assign_dim = int(max_num_node * prog_args.pool_ratio) |
| 195 | print("++++++++++MODEL STATISTICS++++++++") |
| 196 | print("model hidden dim is", hidden_dim) |
| 197 | print("model embedding dim for graph instance embedding", embedding_dim) |
| 198 | print("initial batched pool graph dim is", assign_dim) |
| 199 | activation = F.relu |
| 200 | |
| 201 | # initialize model |
| 202 | # 'diffpool' : diffpool |
| 203 | model = DiffPool( |
| 204 | input_dim, |
| 205 | hidden_dim, |
| 206 | embedding_dim, |
| 207 | label_dim, |
| 208 | activation, |
| 209 | prog_args.gc_per_block, |
| 210 | prog_args.dropout, |
| 211 | prog_args.num_pool, |
| 212 | prog_args.linkpred, |
| 213 | prog_args.batch_size, |
| 214 | "meanpool", |
| 215 | assign_dim, |
| 216 | prog_args.pool_ratio, |
no test coverage detected