MCPcopy
hub / github.com/THUDM/CogDL / PreModel

Class PreModel

examples/graphmae2/models/edcoder.py:47–364  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

45 return mod
46
47 class PreModel(nn.Module):
48 def __init__(
49 self,
50 in_dim: int,
51 num_hidden: int,
52 num_layers: int,
53 num_dec_layers: int,
54 num_remasking: int,
55 nhead: int,
56 nhead_out: int,
57 activation: str,
58 feat_drop: float,
59 attn_drop: float,
60 negative_slope: float,
61 residual: bool,
62 norm: Optional[str],
63 mask_rate: float = 0.3,
64 remask_rate: float = 0.5,
65 remask_method: str = "random",
66 mask_method: str = "random",
67 encoder_type: str = "gat",
68 decoder_type: str = "gat",
69 loss_fn: str = "byol",
70 drop_edge_rate: float = 0.0,
71 alpha_l: float = 2,
72 lam: float = 1.0,
73 delayed_ema_epoch: int = 0,
74 momentum: float = 0.996,
75 replace_rate: float = 0.0,
76 ):
77 super(PreModel, self).__init__()
78 self._mask_rate = mask_rate
79 self._remask_rate = remask_rate
80 self._mask_method = mask_method
81 self._alpha_l = alpha_l
82 self._delayed_ema_epoch = delayed_ema_epoch
83
84 self.num_remasking = num_remasking
85 self._encoder_type = encoder_type
86 self._decoder_type = decoder_type
87 self._drop_edge_rate = drop_edge_rate
88 self._output_hidden_size = num_hidden
89 self._momentum = momentum
90 self._replace_rate = replace_rate
91 self._num_remasking = num_remasking
92 self._remask_method = remask_method
93
94 self._token_rate = 1 - self._replace_rate
95 self._lam = lam
96
97 assert num_hidden % nhead == 0
98 assert num_hidden % nhead_out == 0
99 if encoder_type in ("gat",):
100 enc_num_hidden = num_hidden // nhead
101 enc_nhead = nhead
102 else:
103 enc_num_hidden = num_hidden
104 enc_nhead = 1

Callers 1

build_model · Function · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected