(model)
| 46 | |
| 47 | |
def _prepare_transformer(model):
    """Strip the token-embedding and output layers from ``model`` in place.

    The width of ``model.tok_embeddings`` is read *before* it is replaced,
    because ``nn.Identity`` carries no ``embedding_dim`` attribute.
    ``model.tok_embeddings`` and then ``model.output`` are each replaced
    with a fresh ``nn.Identity()``.

    Args:
        model: Module exposing ``tok_embeddings`` (an object with an
            ``embedding_dim`` attribute) and ``output`` attributes. It is
            mutated in place.

    Returns:
        A ``(model, embed_dim)`` tuple. ``model`` is the same object that was
        passed in, and ``embed_dim`` is the original embedding width.
    """
    dim = model.tok_embeddings.embedding_dim
    # Same assignment order as before; each layer gets its own Identity.
    for attr_name in ("tok_embeddings", "output"):
        setattr(model, attr_name, nn.Identity())
    return model, dim
| 53 | |
| 54 | |
| 55 | def _create_causal_mask(seq_len: int, device: torch.device): |