()
| 43 | |
| 44 | |
| 45 | def build_args(): |
| 46 | parser = argparse.ArgumentParser(description="GAT") |
| 47 | parser.add_argument("--seeds", type=int, nargs="+", default=[0]) |
| 48 | parser.add_argument("--dataset", type=str, default="cora") |
| 49 | parser.add_argument("--device", type=int, default=0) |
| 50 | parser.add_argument("--max_epoch", type=int, default=500, |
| 51 | help="number of training epochs") |
| 52 | parser.add_argument("--warmup_steps", type=int, default=-1) |
| 53 | |
| 54 | parser.add_argument("--num_heads", type=int, default=4, |
| 55 | help="number of hidden attention heads") |
| 56 | parser.add_argument("--num_out_heads", type=int, default=1, |
| 57 | help="number of output attention heads") |
| 58 | parser.add_argument("--num_layers", type=int, default=2, |
| 59 | help="number of hidden layers") |
| 60 | parser.add_argument("--num_dec_layers", type=int, default=1) |
| 61 | parser.add_argument("--num_remasking", type=int, default=3) |
| 62 | parser.add_argument("--num_hidden", type=int, default=512, |
| 63 | help="number of hidden units") |
| 64 | parser.add_argument("--residual", action="store_true", default=False, |
| 65 | help="use residual connection") |
| 66 | parser.add_argument("--in_drop", type=float, default=.2, |
| 67 | help="input feature dropout") |
| 68 | parser.add_argument("--attn_drop", type=float, default=.1, |
| 69 | help="attention dropout") |
| 70 | parser.add_argument("--norm", type=str, default=None) |
| 71 | parser.add_argument("--lr", type=float, default=0.001, |
| 72 | help="learning rate") |
| 73 | parser.add_argument("--weight_decay", type=float, default=0, |
| 74 | help="weight decay") |
| 75 | parser.add_argument("--negative_slope", type=float, default=0.2, |
| 76 | help="the negative slope of leaky relu") |
| 77 | parser.add_argument("--activation", type=str, default="prelu") |
| 78 | parser.add_argument("--mask_rate", type=float, default=0.5) |
| 79 | parser.add_argument("--remask_rate", type=float, default=0.5) |
| 80 | parser.add_argument("--remask_method", type=str, default="random") |
| 81 | parser.add_argument("--mask_type", type=str, default="mask", |
| 82 | help="`mask` or `drop`") |
| 83 | parser.add_argument("--mask_method", type=str, default="random") |
| 84 | parser.add_argument("--drop_edge_rate", type=float, default=0.0) |
| 85 | parser.add_argument("--drop_edge_rate_f", type=float, default=0.0) |
| 86 | |
| 87 | parser.add_argument("--encoder", type=str, default="gat") |
| 88 | parser.add_argument("--decoder", type=str, default="gat") |
| 89 | parser.add_argument("--loss_fn", type=str, default="sce") |
| 90 | parser.add_argument("--alpha_l", type=float, default=2) |
| 91 | parser.add_argument("--optimizer", type=str, default="adam") |
| 92 | |
| 93 | parser.add_argument("--max_epoch_f", type=int, default=300) |
| 94 | parser.add_argument("--lr_f", type=float, default=0.01) |
| 95 | parser.add_argument("--weight_decay_f", type=float, default=0.0) |
| 96 | parser.add_argument("--linear_prob", action="store_true", default=False) |
| 97 | |
| 98 | |
| 99 | parser.add_argument("--no_pretrain", action="store_true") |
| 100 | parser.add_argument("--load_model", action="store_true") |
| 101 | parser.add_argument("--checkpoint_path", type=str, default=None) |
| 102 | parser.add_argument("--use_cfg", action="store_true") |
no outgoing calls
no test coverage detected