| 45 | return mod |
| 46 | |
| 47 | class PreModel(nn.Module): |
| 48 | def __init__( |
| 49 | self, |
| 50 | in_dim: int, |
| 51 | num_hidden: int, |
| 52 | num_layers: int, |
| 53 | num_dec_layers: int, |
| 54 | num_remasking: int, |
| 55 | nhead: int, |
| 56 | nhead_out: int, |
| 57 | activation: str, |
| 58 | feat_drop: float, |
| 59 | attn_drop: float, |
| 60 | negative_slope: float, |
| 61 | residual: bool, |
| 62 | norm: Optional[str], |
| 63 | mask_rate: float = 0.3, |
| 64 | remask_rate: float = 0.5, |
| 65 | remask_method: str = "random", |
| 66 | mask_method: str = "random", |
| 67 | encoder_type: str = "gat", |
| 68 | decoder_type: str = "gat", |
| 69 | loss_fn: str = "byol", |
| 70 | drop_edge_rate: float = 0.0, |
| 71 | alpha_l: float = 2, |
| 72 | lam: float = 1.0, |
| 73 | delayed_ema_epoch: int = 0, |
| 74 | momentum: float = 0.996, |
| 75 | replace_rate: float = 0.0, |
| 76 | ): |
| 77 | super(PreModel, self).__init__() |
| 78 | self._mask_rate = mask_rate |
| 79 | self._remask_rate = remask_rate |
| 80 | self._mask_method = mask_method |
| 81 | self._alpha_l = alpha_l |
| 82 | self._delayed_ema_epoch = delayed_ema_epoch |
| 83 | |
| 84 | self.num_remasking = num_remasking |
| 85 | self._encoder_type = encoder_type |
| 86 | self._decoder_type = decoder_type |
| 87 | self._drop_edge_rate = drop_edge_rate |
| 88 | self._output_hidden_size = num_hidden |
| 89 | self._momentum = momentum |
| 90 | self._replace_rate = replace_rate |
| 91 | self._num_remasking = num_remasking |
| 92 | self._remask_method = remask_method |
| 93 | |
| 94 | self._token_rate = 1 - self._replace_rate |
| 95 | self._lam = lam |
| 96 | |
| 97 | assert num_hidden % nhead == 0 |
| 98 | assert num_hidden % nhead_out == 0 |
| 99 | if encoder_type in ("gat",): |
| 100 | enc_num_hidden = num_hidden // nhead |
| 101 | enc_nhead = nhead |
| 102 | else: |
| 103 | enc_num_hidden = num_hidden |
| 104 | enc_nhead = 1 |