(self, layer_config, pred_hid, dropout=0.0, moving_average_decay=0.99, epochs=1000, **kwargs)
| 79 | class BGRL(nn.Module): |
| 80 | |
| 81 | def __init__(self, layer_config, pred_hid, dropout=0.0, moving_average_decay=0.99, epochs=1000, **kwargs): |
| 82 | super().__init__() |
| 83 | self.student_encoder = Encoder(layer_config=layer_config, dropout=dropout, **kwargs) |
| 84 | self.teacher_encoder = copy.deepcopy(self.student_encoder) |
| 85 | set_requires_grad(self.teacher_encoder, False) |
| 86 | self.teacher_ema_updater = EMA(moving_average_decay, epochs) |
| 87 | rep_dim = layer_config[-1] |
| 88 | self.student_predictor = nn.Sequential(nn.Linear(rep_dim, pred_hid), nn.PReLU(), nn.Linear(pred_hid, rep_dim)) |
| 89 | self.student_predictor.apply(init_weights) |
| 90 | |
| 91 | def reset_moving_average(self): |
| 92 | del self.teacher_encoder |
No direct callers were found.
No test coverage was detected.