MCPcopy Index your code
hub / github.com/zai-org/CodeGeeX / _transpose_first_dim

Function _transpose_first_dim

codegeex/megatron/checkpointing.py:205–246  ·  view source on GitHub ↗
(t, num_splits, num_splits_first, model)

Source from the content-addressed store, hash-verified

203
204
205def _transpose_first_dim(t, num_splits, num_splits_first, model):
206 input_shape = t.size()
207 # We use a self_attention module but the values extracted aren't
208 # specific to self attention so should work for cross attention as well
209 while hasattr(model, "module"):
210 model = model.module
211 attention_module = model.language_model.encoder.layers[0].self_attention
212 hidden_size_per_attention_head = attention_module.hidden_size_per_attention_head
213 num_attention_heads_per_partition = (
214 attention_module.num_attention_heads_per_partition
215 )
216 if num_splits_first:
217 """[num_splits * np * hn, h]
218 -->(view) [num_splits, np, hn, h]
219 -->(tranpose) [np, num_splits, hn, h]
220 -->(view) [np * num_splits * hn, h]"""
221
222 intermediate_shape = (
223 num_splits,
224 num_attention_heads_per_partition,
225 hidden_size_per_attention_head,
226 ) + input_shape[1:]
227
228 t = t.view(*intermediate_shape)
229 t = t.transpose(0, 1).contiguous()
230 else:
231 """[np * hn * num_splits, h]
232 -->(view) [np, hn, num_splits, h]
233 -->(tranpose) [np, num_splits, hn, h]
234 -->(view) [np * num_splits * hn, h]"""
235
236 intermediate_shape = (
237 num_attention_heads_per_partition,
238 hidden_size_per_attention_head,
239 num_splits,
240 ) + input_shape[1:]
241
242 t = t.view(*intermediate_shape)
243 t = t.transpose(1, 2).contiguous()
244 t = t.view(*input_shape)
245
246 return t
247
248
249def fix_query_key_value_ordering(model, checkpoint_version):

Callers 1

Calls 1

sizeMethod · 0.80

Tested by

no test coverage detected