(type, *args, **kwargs)
| 176 | |
| 177 | |
def norm(type, *args, **kwargs):
    """Construct the normalization layer selected by ``type``.

    Args:
        type: Name of the layer to build, either ``"rmsnorm"`` or
            ``"layernorm"``. The parameter name shadows the builtin but is
            kept so that callers passing ``type=...`` keep working.
        *args: Positional arguments passed through unchanged to the
            selected layer's constructor.
        **kwargs: Keyword arguments passed through unchanged to the
            selected layer's constructor.

    Returns:
        ``RMSNorm(*args, **kwargs)`` when ``type`` is ``"rmsnorm"``, or
        ``nn.LayerNorm(*args, **kwargs)`` when ``type`` is ``"layernorm"``.

    Raises:
        ValueError: If ``type`` is neither of the two supported names.
    """
    if type == "rmsnorm":
        return RMSNorm(*args, **kwargs)
    if type == "layernorm":
        return nn.LayerNorm(*args, **kwargs)
    # Comparison is exact (case-sensitive); anything else is rejected.
    raise ValueError(f"Unknown norm type {type}")
| 185 | |
| 186 | |
| 187 | def dot_product_attention_weights( |