| 2 | |
| 3 | |
| 4 | def build_model(args): |
| 5 | num_heads = args.num_heads |
| 6 | num_out_heads = args.num_out_heads |
| 7 | num_hidden = args.num_hidden |
| 8 | num_layers = args.num_layers |
| 9 | residual = args.residual |
| 10 | attn_drop = args.attn_drop |
| 11 | in_drop = args.in_drop |
| 12 | norm = args.norm |
| 13 | negative_slope = args.negative_slope |
| 14 | encoder_type = args.encoder |
| 15 | decoder_type = args.decoder |
| 16 | mask_rate = args.mask_rate |
| 17 | remask_rate = args.remask_rate |
| 18 | mask_method = args.mask_method |
| 19 | drop_edge_rate = args.drop_edge_rate |
| 20 | |
| 21 | activation = args.activation |
| 22 | loss_fn = args.loss_fn |
| 23 | alpha_l = args.alpha_l |
| 24 | |
| 25 | num_features = args.num_features |
| 26 | num_dec_layers = args.num_dec_layers |
| 27 | num_remasking = args.num_remasking |
| 28 | lam = args.lam |
| 29 | delayed_ema_epoch = args.delayed_ema_epoch |
| 30 | replace_rate = args.replace_rate |
| 31 | remask_method = args.remask_method |
| 32 | momentum = args.momentum |
| 33 | |
| 34 | model = PreModel( |
| 35 | in_dim=num_features, |
| 36 | num_hidden=num_hidden, |
| 37 | num_layers=num_layers, |
| 38 | num_dec_layers=num_dec_layers, |
| 39 | num_remasking=num_remasking, |
| 40 | nhead=num_heads, |
| 41 | nhead_out=num_out_heads, |
| 42 | activation=activation, |
| 43 | feat_drop=in_drop, |
| 44 | attn_drop=attn_drop, |
| 45 | negative_slope=negative_slope, |
| 46 | residual=residual, |
| 47 | encoder_type=encoder_type, |
| 48 | decoder_type=decoder_type, |
| 49 | mask_rate=mask_rate, |
| 50 | remask_rate=remask_rate, |
| 51 | mask_method=mask_method, |
| 52 | norm=norm, |
| 53 | loss_fn=loss_fn, |
| 54 | drop_edge_rate=drop_edge_rate, |
| 55 | alpha_l=alpha_l, |
| 56 | lam=lam, |
| 57 | delayed_ema_epoch=delayed_ema_epoch, |
| 58 | replace_rate=replace_rate, |
| 59 | remask_method=remask_method, |
| 60 | momentum=momentum |
| 61 | ) |