Repository: github.com/dmlc/dgl

Function graph_classify_task

examples/pytorch/diffpool/train.py:159–239

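Perform a graph classification task: split a TU dataset into training, validation and test sets, build a DiffPool model, then train and evaluate it.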

(prog_args)
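`prog_args` is only ever read as a bag of attributes. If you want to call the function directly rather than through `main`, a minimal sketch is an `argparse.Namespace` carrying the fields the excerpt reads. The values below are made up for illustration and are not the example's defaults; only the field names come from the source.

```python
import argparse

# Illustrative values only (assumption): these are NOT the defaults of the
# DGL diffpool script. The field names are the ones graph_classify_task reads.
prog_args = argparse.Namespace(
    dataset="ENZYMES",   # name passed to tu.LegacyTUDataset
    train_ratio=0.7,     # fraction of graphs used for training
    test_ratio=0.1,      # fraction used for testing; validation gets the rest
    pool_ratio=0.1,      # assign_dim = int(max_num_node * pool_ratio)
    gc_per_block=3,      # passed straight to DiffPool(...)
    dropout=0.0,         # passed straight to DiffPool(...)
    num_pool=1,          # passed straight to DiffPool(...)
    linkpred=False,      # passed straight to DiffPool(...)
    batch_size=20,       # passed to DiffPool(...) and used by prepare_data
)

print(prog_args.dataset, prog_args.train_ratio, prog_args.test_ratio)
```

Keeping `train_ratio + test_ratio` at or below 1 matters, because the validation size is computed as whatever remains.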


def graph_classify_task(prog_args):
    """
    perform graph classification task
    """

    dataset = tu.LegacyTUDataset(name=prog_args.dataset)
    train_size = int(prog_args.train_ratio * len(dataset))
    test_size = int(prog_args.test_ratio * len(dataset))
    val_size = int(len(dataset) - train_size - test_size)

    dataset_train, dataset_val, dataset_test = torch.utils.data.random_split(
        dataset, (train_size, val_size, test_size)
    )
    train_dataloader = prepare_data(
        dataset_train, prog_args, train=True, pre_process=pre_process
    )
    val_dataloader = prepare_data(
        dataset_val, prog_args, train=False, pre_process=pre_process
    )
    test_dataloader = prepare_data(
        dataset_test, prog_args, train=False, pre_process=pre_process
    )
    input_dim, label_dim, max_num_node = dataset.statistics()
    print("++++++++++STATISTICS ABOUT THE DATASET")
    print("dataset feature dimension is", input_dim)
    print("dataset label dimension is", label_dim)
    print("the max num node is", max_num_node)
    print("number of graphs is", len(dataset))
    # assert len(dataset) % prog_args.batch_size == 0, "training set not divisible by batch size"

    hidden_dim = 64  # used to be 64
    embedding_dim = 64

    # calculate assignment dimension: pool_ratio * largest graph's maximum
    # number of nodes in the dataset
    assign_dim = int(max_num_node * prog_args.pool_ratio)
    print("++++++++++MODEL STATISTICS++++++++")
    print("model hidden dim is", hidden_dim)
    print("model embedding dim for graph instance embedding", embedding_dim)
    print("initial batched pool graph dim is", assign_dim)
    activation = F.relu

    # initialize model
    # 'diffpool' : diffpool
    model = DiffPool(
        input_dim,
        hidden_dim,
        embedding_dim,
        label_dim,
        activation,
        prog_args.gc_per_block,
        prog_args.dropout,
        prog_args.num_pool,
        prog_args.linkpred,
        prog_args.batch_size,
        "meanpool",
        assign_dim,
        prog_args.pool_ratio,
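The split and assignment-dimension arithmetic in the excerpt can be checked in isolation. The sketch below is plain Python, not DGL code: `split_sizes`, `assignment_dim` and `random_index_split` are hypothetical helpers, and `random_index_split` is a stand-in that shuffles indices and slices them by length, in the spirit of `torch.utils.data.random_split` rather than a copy of it. The dataset sizes used are made up.

```python
import random


def split_sizes(num_graphs, train_ratio, test_ratio):
    """Mirror the excerpt's split arithmetic (hypothetical helper).

    Train and test sizes are truncated with int(); validation takes what is
    left, so the three sizes always sum to num_graphs.
    """
    train_size = int(train_ratio * num_graphs)
    test_size = int(test_ratio * num_graphs)
    val_size = num_graphs - train_size - test_size
    if val_size < 0:
        # A negative validation size is meaningless; fail early.
        raise ValueError("train_ratio + test_ratio must not exceed 1")
    return train_size, val_size, test_size


def assignment_dim(max_num_node, pool_ratio):
    """Clusters in the first DiffPool assignment, as the excerpt computes it."""
    return int(max_num_node * pool_ratio)


def random_index_split(num_items, lengths, seed=0):
    """Shuffle 0..num_items-1 and cut it into disjoint consecutive slices."""
    if sum(lengths) != num_items:
        raise ValueError("lengths must sum to num_items")
    rng = random.Random(seed)
    perm = list(range(num_items))
    rng.shuffle(perm)
    parts, start = [], 0
    for n in lengths:
        parts.append(perm[start:start + n])
        start += n
    return parts


print(split_sizes(1113, 0.8, 0.1))  # (890, 112, 111)
print(assignment_dim(126, 0.25))    # 31
```

Because `int()` truncates, the validation split absorbs the rounding remainder (112 graphs against 111 for test in this made-up case), and `assign_dim` is likewise rounded down. If `train_ratio + test_ratio` exceeds 1, the excerpt's `val_size` can go negative; the sketch raises in that case instead.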

Callers (1)

main (function) · 0.85

Calls (9)

statistics (method) · 0.95
DiffPool (class) · 0.90
load_state_dict (method) · 0.80
cuda (method) · 0.80
format (method) · 0.80
prepare_data (function) · 0.70
train (function) · 0.70
evaluate (function) · 0.70
load (method) · 0.45

Tested by

no test coverage detected