Loads weights from src into dst via in-place copy. src is a Hugging Face GPT-2 model, while dst is one of our models. Passing dst2src=True reverses the direction and loads parameters from our model into Hugging Face's. Note: the dst2src path is still untested.
(src, dst, dst2src=False)
| 412 | |
| 413 | |
def load_weights(src, dst, dst2src=False):
    """
    Copy parameter values between ``src`` and ``dst`` in place.

    ``src`` is a Hugging Face GPT-2 model (or one of its layers), while
    ``dst`` is one of our models. Parameters are matched by the names that
    ``src.named_parameters()`` yields, so both modules must expose the same
    parameter names (dotted names for nested sub-modules are supported).

    Args:
        src: module whose parameter names drive the iteration.
        dst: module holding the identically named counterpart parameters.
        dst2src: when False (default) values flow ``src -> dst``; when True
            they flow ``dst -> src``. The reverse direction was flagged as
            untested by the original author.

    Raises:
        KeyError: if ``dst`` has no parameter matching a name from ``src``.
    """
    # Weights are transposed only when `src` itself is a Conv1D layer.
    # NOTE(review): presumably because HF's Conv1D stores its weight
    # transposed relative to nn.Linear -- confirm against the HF docs.
    conv_layer = 'Conv1D' in str(type(src))

    # `_parameters` only holds parameters registered directly on `dst`;
    # nested parameters (dotted names) must be resolved through
    # named_parameters().
    nested_params = dict(dst.named_parameters())

    for name, src_param in src.named_parameters():
        # Prefer the direct entry so leaf modules behave exactly as before.
        dst_param = dst._parameters.get(name)
        if dst_param is None:
            try:
                dst_param = nested_params[name]
            except KeyError:
                raise KeyError(
                    'dst has no parameter named {!r}'.format(name)
                ) from None

        if dst2src:
            source, target = dst_param.data, src_param.data
        else:
            source, target = src_param.data, dst_param.data

        if conv_layer and 'weight' in name:
            # copy_ needs matching shapes, so flip the 2-D weight first.
            source = source.t().contiguous()
        target.copy_(source)
| 432 | |
| 433 | |
| 434 | # dst._parameters[n].data.copy_(data) |
no test coverage detected