(model: PretrainedModel)
| 1513 | |
| 1514 | |
def parallelize_embedding(model: PretrainedModel) -> PretrainedModel:
    """Rebuild non-tensor-parallel ``Embedding`` submodules with TP settings.

    Every submodule of ``model`` that is an ``Embedding`` and whose
    ``tp_group`` is ``None`` is replaced on its parent by a new instance of
    the same class. The replacement is constructed from the keyword
    arguments that ``get_init_params`` returns for the old instance, with
    these overrides:

    * ``tp_group``, ``tp_size`` and ``tp_rank`` taken from ``model.config.mapping``;
    * ``sharding_dim`` taken from ``model.config.embedding_sharding_dim``.

    Embeddings that already have a ``tp_group`` are left untouched.

    Only constructor arguments are forwarded. No other state is copied from
    the old instance. NOTE(review): presumably this runs before weights are
    loaded -- confirm with callers.

    Args:
        model: Model whose submodules are rewritten; it is modified in place.

    Returns:
        The same ``model`` object that was passed in.
    """
    for qualified_name, module, owner in model.named_modules_with_parent():
        # Guard clause: only un-parallelized Embedding layers are rebuilt.
        if not isinstance(module, Embedding) or module.tp_group is not None:
            continue

        ctor_kwargs = get_init_params(module)
        mapping = model.config.mapping
        overrides = {
            "tp_group": mapping.tp_group,
            "tp_size": mapping.tp_size,
            "tp_rank": mapping.tp_rank,
            "sharding_dim": model.config.embedding_sharding_dim,
        }
        for key, value in overrides.items():
            ctor_kwargs[key] = value

        # The attribute name on the parent is the last dotted component.
        attr_name = qualified_name.rpartition('.')[2]
        setattr(owner, attr_name, module.__class__(**ctor_kwargs))
    return model
| 1527 | |
| 1528 | |
| 1529 | def share_embedding(model: PretrainedModel) -> PretrainedModel: |
no test coverage detected