MCPcopy
hub / github.com/EleutherAI/gpt-neox / test_fused_softmax

Function test_fused_softmax

tests/model/test_fused_kernels.py:44–142  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

42
43
44def test_fused_softmax():
45 from megatron.model.fused_softmax import FusedScaleMaskSoftmax, SoftmaxFusionTypes
46 from megatron.model.gpt2_model import (
47 gpt2_attention_mask_func as attention_mask_func,
48 )
49
50 bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
51 tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
52 test_text = (
53 "Hello. How are you? I am fine thank you and you? yes Good. "
54 "hi hi hi hi hi hi hi hi hi hi hi hi hi" # 32
55 )
56
57 tokens = tokenizer(
58 [test_text] * 4,
59 return_tensors="pt",
60 )
61
62 embedding_output = bert.embeddings(
63 input_ids=tokens["input_ids"].cuda(),
64 position_ids=None,
65 token_type_ids=tokens["token_type_ids"].cuda(),
66 inputs_embeds=None,
67 past_key_values_length=0,
68 )
69
70 # (bsz, 1, 1, seq_len)
71 mask = bert.get_extended_attention_mask(
72 attention_mask=tokens["attention_mask"].cuda(),
73 input_shape=tokens["input_ids"].shape,
74 device=bert.device,
75 )
76 # (bsz, 1, seq_len, seq_len)
77 mask = mask.repeat(1, 1, mask.size()[-1], 1)
78
79 attention = bert.encoder.layer[0].attention.self
80 key_layer = attention.transpose_for_scores(attention.key(embedding_output))
81 query_layer = attention.transpose_for_scores(attention.query(embedding_output))
82
83 attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
84 attention_scores /= math.sqrt(key_layer.size()[-1])
85
86 fused_softmax = (
87 FusedScaleMaskSoftmax(
88 input_in_fp16=True,
89 input_in_bf16=False,
90 fusion_type=SoftmaxFusionTypes.general,
91 mask_func=attention_mask_func,
92 scale=None,
93 softmax_in_fp32=False,
94 )
95 .cuda()
96 .half()
97 )
98
99 fused_softmax_output = fused_softmax(
100 attention_scores,
101 (mask != 0),

Callers

nothing calls this directly

Calls 3

from_pretrained · Method · 0.80
size · Method · 0.80

Tested by

no test coverage detected