(t, num_splits, num_splits_first, model)
| 203 | |
| 204 | |
def _transpose_first_dim(t, num_splits, num_splits_first, model):
    """Reorder the fused leading dimension of ``t`` to ``[np, num_splits, hn]`` order.

    ``np`` is the attention module's ``num_attention_heads_per_partition`` and
    ``hn`` is its ``hidden_size_per_attention_head``. The leading dimension of
    ``t`` must hold exactly ``num_splits * np * hn`` entries. Any trailing
    dimensions are carried through untouched, so a 1-D tensor (no trailing
    dims) is handled as well.

    Args:
        t: tensor whose first dimension is a fused ``num_splits``/head/head-dim
            axis.
        num_splits: number of fused chunks packed into the first dimension.
        num_splits_first: if true, the input layout is
            ``[num_splits, np, hn, ...]``; otherwise it is
            ``[np, hn, num_splits, ...]``.
        model: model to read head geometry from. Any chain of ``.module``
            wrappers is unwrapped first.

    Returns:
        A contiguous tensor with the same shape as ``t``, with the leading
        axis laid out as ``[np, num_splits, hn]``.
    """
    original_shape = t.size()

    # Strip any number of `.module` wrapper layers to reach the bare model.
    unwrapped = model
    while hasattr(unwrapped, "module"):
        unwrapped = unwrapped.module

    # Head geometry comes from the first encoder layer's self-attention module.
    # The values are not specific to self-attention, so they also apply to
    # cross-attention tensors.
    attn = unwrapped.language_model.encoder.layers[0].self_attention
    head_dim = attn.hidden_size_per_attention_head
    heads = attn.num_attention_heads_per_partition
    trailing = original_shape[1:]

    if num_splits_first:
        # [num_splits * np * hn, ...] -> [num_splits, np, hn, ...]
        # then swap axes 0 and 1 -> [np, num_splits, hn, ...]
        swapped = t.view(num_splits, heads, head_dim, *trailing).transpose(0, 1)
    else:
        # [np * hn * num_splits, ...] -> [np, hn, num_splits, ...]
        # then swap axes 1 and 2 -> [np, num_splits, hn, ...]
        swapped = t.view(heads, head_dim, num_splits, *trailing).transpose(1, 2)

    # The transpose yields a non-contiguous view; materialize it before
    # collapsing back to the caller's original shape.
    return swapped.contiguous().view(*original_shape)
| 247 | |
| 248 | |
| 249 | def fix_query_key_value_ordering(model, checkpoint_version): |
no test coverage detected