()
| 42 | |
| 43 | |
| 44 | def test_fused_softmax(): |
| 45 | from megatron.model.fused_softmax import FusedScaleMaskSoftmax, SoftmaxFusionTypes |
| 46 | from megatron.model.gpt2_model import ( |
| 47 | gpt2_attention_mask_func as attention_mask_func, |
| 48 | ) |
| 49 | |
| 50 | bert = BertModel.from_pretrained("bert-base-cased").cuda().half() |
| 51 | tokenizer = BertTokenizer.from_pretrained("bert-base-cased") |
| 52 | test_text = ( |
| 53 | "Hello. How are you? I am fine thank you and you? yes Good. " |
| 54 | "hi hi hi hi hi hi hi hi hi hi hi hi hi" # 32 |
| 55 | ) |
| 56 | |
| 57 | tokens = tokenizer( |
| 58 | [test_text] * 4, |
| 59 | return_tensors="pt", |
| 60 | ) |
| 61 | |
| 62 | embedding_output = bert.embeddings( |
| 63 | input_ids=tokens["input_ids"].cuda(), |
| 64 | position_ids=None, |
| 65 | token_type_ids=tokens["token_type_ids"].cuda(), |
| 66 | inputs_embeds=None, |
| 67 | past_key_values_length=0, |
| 68 | ) |
| 69 | |
| 70 | # (bsz, 1, 1, seq_len) |
| 71 | mask = bert.get_extended_attention_mask( |
| 72 | attention_mask=tokens["attention_mask"].cuda(), |
| 73 | input_shape=tokens["input_ids"].shape, |
| 74 | device=bert.device, |
| 75 | ) |
| 76 | # (bsz, 1, seq_len, seq_len) |
| 77 | mask = mask.repeat(1, 1, mask.size()[-1], 1) |
| 78 | |
| 79 | attention = bert.encoder.layer[0].attention.self |
| 80 | key_layer = attention.transpose_for_scores(attention.key(embedding_output)) |
| 81 | query_layer = attention.transpose_for_scores(attention.query(embedding_output)) |
| 82 | |
| 83 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) |
| 84 | attention_scores /= math.sqrt(key_layer.size()[-1]) |
| 85 | |
| 86 | fused_softmax = ( |
| 87 | FusedScaleMaskSoftmax( |
| 88 | input_in_fp16=True, |
| 89 | input_in_bf16=False, |
| 90 | fusion_type=SoftmaxFusionTypes.general, |
| 91 | mask_func=attention_mask_func, |
| 92 | scale=None, |
| 93 | softmax_in_fp32=False, |
| 94 | ) |
| 95 | .cuda() |
| 96 | .half() |
| 97 | ) |
| 98 | |
| 99 | fused_softmax_output = fused_softmax( |
| 100 | attention_scores, |
| 101 | (mask != 0), |
nothing calls this directly
no test coverage detected