MCPcopy
hub / github.com/Jiayi-Pan/TinyZero / test_hf_casual_models

Function test_hf_casual_models

tests/model/test_transformer.py:32–88  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

30
31
def test_hf_casual_models():
    """Check that unpadded (varlen) and padded forward passes give matching log-probs.

    For every config in the module-level ``test_configs`` (defined outside this
    view), build a bf16 causal LM with ``attn_implementation='flash_attention_2'``
    on CUDA. Then run it twice on the same random batch:

    1. with padding removed via ``unpad_input``, as a single packed row of shape
       ``(1, total_nnz)`` with matching unpadded ``position_ids`` (no attention
       mask), so that flash attention can run in varlen mode;
    2. on the original padded ``input_ids`` with ``attention_mask``.

    The log-probs from both passes are reduced with ``masked_mean`` over the
    response window and compared with ``torch.testing.assert_close``.

    Requires a CUDA device and flash-attn.
    """
    batch_size = 4
    seqlen = 128
    response_length = 127

    for config in test_configs:
        with torch.device('cuda'):
            model = AutoModelForCausalLM.from_config(config=config,
                                                     torch_dtype=torch.bfloat16,
                                                     attn_implementation='flash_attention_2')
            model = model.to(device='cuda')
        # from_config does not switch the model to eval mode; disable dropout so
        # the two forward passes below are deterministic and comparable even if
        # a config enables attention/hidden dropout.
        model.eval()

        input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device='cuda')
        attention_mask = create_random_mask(input_ids=input_ids,
                                            max_ratio_of_left_padding=0.1,
                                            max_ratio_of_valid_token=0.8,
                                            min_ratio_of_valid_token=0.5)
        position_ids = compute_position_id_with_mask(
            attention_mask)  # TODO(sgm): we can construct the position_ids_rmpad here

        input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
                                                   attention_mask)  # input_ids_rmpad (total_nnz, ...)
        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

        # Unpad the position_ids the same way so the rotary embedding stays aligned.
        position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
                                              indices).transpose(0, 1)

        # Only logits are compared; no autograd graph is needed.
        with torch.no_grad():
            # Packed input + position_ids and no attention mask, for the flash-attention varlen path.
            logits_rmpad = model(input_ids_rmpad, position_ids=position_ids_rmpad,
                                 use_cache=False).logits  # (1, total_nnz, vocab_size)

            origin_logits = model(input_ids=input_ids,
                                  attention_mask=attention_mask,
                                  position_ids=position_ids,
                                  use_cache=False).logits
        # Same mask as above, so origin_logits_indices should equal `indices`.
        origin_logits_rmpad, origin_logits_indices, *_ = unpad_input(origin_logits, attention_mask)

        logits_rmpad = logits_rmpad.squeeze(0)
        # NOTE(review): shapes inferred from the mask slice they are combined with
        # below; expected (batch_size, response_length) — confirm against the helper.
        log_probs = log_probs_from_logits_all_rmpad(input_ids_rmpad=input_ids_rmpad,
                                                    logits_rmpad=logits_rmpad,
                                                    indices=indices,
                                                    batch_size=batch_size,
                                                    seqlen=seqlen,
                                                    response_length=response_length)
        origin_log_probs = log_probs_from_logits_all_rmpad(input_ids_rmpad=input_ids_rmpad,
                                                           logits_rmpad=origin_logits_rmpad,
                                                           indices=origin_logits_indices,
                                                           batch_size=batch_size,
                                                           seqlen=seqlen,
                                                           response_length=response_length)

        # Same window that both log-prob tensors are compared over.
        response_mask = attention_mask[:, -response_length - 1:-1]
        torch.testing.assert_close(masked_mean(log_probs, response_mask),
                                   masked_mean(origin_log_probs, response_mask),
                                   atol=1e-2,
                                   rtol=1e-5)
    print('Check pass')
89

Callers 1

Calls 6

create_random_mask — Function · 0.90
masked_mean — Function · 0.90
from_config — Method · 0.80
to — Method · 0.80

Tested by

no test coverage detected