Loads weights from src to dst via an in-place copy. src is a Hugging Face GPT-2 model, while dst is one of our models. Setting dst2src=True instead loads parameters from our model into the Hugging Face one. Note: the dst2src path is still untested.
(src, dst, dst2src=False)
| 381 | |
| 382 | |
def load_weights(src, dst, dst2src=False):
    """
    Copy parameter values between ``src`` and ``dst`` in place.

    Each ``(name, parameter)`` pair yielded by ``src.named_parameters()`` is
    matched with the entry of the same name in ``dst._parameters``; one
    tensor is then overwritten with the other through ``copy_``.

    Per the original author's notes, ``src`` is a huggingface gpt2model and
    ``dst`` is one of our models.

    Args:
        src: object exposing ``named_parameters()``; its type name decides
            whether weights are transposed (see below).
        dst: object whose ``_parameters`` mapping holds an entry for every
            name yielded by ``src``.
        dst2src: when False (default) values flow from ``src`` into ``dst``;
            when True they flow from ``dst`` into ``src``. The original
            author marks the ``dst2src=True`` path as still untested.

    Raises:
        KeyError: if ``dst._parameters`` has no entry for a name from ``src``.
    """
    # Transposition applies only when the type name of src mentions Conv1D.
    # NOTE(review): presumably Conv1D keeps its weight transposed relative
    # to the layer used on the dst side -- confirm against the callers.
    needs_transpose = 'Conv1D' in str(type(src))

    for name, src_param in src.named_parameters():
        dst_param = dst._parameters[name]

        # Select which tensor is read from and which one is overwritten.
        if dst2src:
            source, target = dst_param.data, src_param.data
        else:
            source, target = src_param.data, dst_param.data

        if needs_transpose and 'weight' in name:
            source = source.t().contiguous()

        target.copy_(source)
| 401 | |
| 402 | |
| 403 | # dst._parameters[n].data.copy_(data) |
no outgoing calls
no test coverage detected