Parameters
----------
device : str, optional
    The device to run the model on ('cpu' or 'cuda'). By default, 'cuda' is used.
(self, device="cuda")
| 27 | """ |
| 28 | |
| 29 | def __init__(self, device="cuda"): |
| 30 | """ |
| 31 | Parameters |
| 32 | ---------- |
| 33 | device : str, optional |
| 34 | The device to run the model on ('cpu' or 'cuda'). By default, 'cuda' is used. |
| 35 | """ |
| 36 | self.device = device |
| 37 | self.model = BitNetTransformer(num_tokens=256, dim=512, depth=8) |
| 38 | self.model = AutoregressiveWrapper(self.model, max_seq_len=1024) |
| 39 | self.model.to(self.device) |
| 40 | |
| 41 | def load_model(self, model_path): |
| 42 | """Loads a trained model from a .pth file.""" |
nothing calls this directly
no test coverage detected