MCPcopy
hub / github.com/hkust-nlp/simpleRL-reason / test_hf_casual_fwd

Function test_hf_casual_fwd

tests/model/test_transformers_ulysses.py:50–125  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

48
49
50def test_hf_casual_fwd():
51 assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test"
52
53 sp_size = 8
54 dp_size = 1
55 ulysses_device_mesh = init_device_mesh(device_type='cuda',
56 mesh_shape=(dp_size, sp_size),
57 mesh_dim_names=('dp', 'sp'))
58 sharding_manager = FSDPUlyssesShardingManager(ulysses_device_mesh)
59
60 batch_size = 1
61 seqlen = 128
62 response_length = 127
63
64 for model_name, (config, attn) in test_configs.items():
65 # patch before load
66 attn.forward = patches[model_name]
67 with torch.device('cuda'):
68 model = AutoModelForCausalLM.from_config(config=config,
69 torch_dtype=torch.bfloat16,
70 attn_implementation='flash_attention_2')
71 model = model.to(device='cuda')
72 sync_model_parameters_global(model)
73
74 # different rank will generate different input_ids following fsdp
75 input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device='cuda')
76 attention_mask = create_random_mask(input_ids=input_ids,
77 max_ratio_of_left_padding=0,
78 max_ratio_of_valid_token=0.9,
79 min_ratio_of_valid_token=0.8)
80 position_ids = compute_position_id_with_mask(
81 attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here
82
83 model_inputs = {
84 'input_ids': input_ids.cuda(),
85 'attention_mask': attention_mask.cuda(),
86 'position_ids': position_ids.int().cuda()
87 }
88
89 model_inputs = DataProto.from_dict(model_inputs)
90
91 # 1. perform ulysses forward
92 with sharding_manager:
93 model_inputs = sharding_manager.preprocess_data(model_inputs)
94 input_ids = model_inputs.batch['input_ids']
95 attention_mask = model_inputs.batch['attention_mask']
96 position_ids = model_inputs.batch['position_ids']
97 input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
98 attention_mask) # input_ids_rmpad (total_nnz, ...)
99 input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
100 # unpad the position_ids to align the rotary
101 position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
102 indices).transpose(0, 1)
103
104 # slice input tensor for ulysses
105 # input_ids are padded and sliced
106 # position_ids are only padded but not sliced
107 input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(

Callers 1

Calls 13

preprocess_dataMethod · 0.95
create_random_maskFunction · 0.90
gather_outpus_and_unpadFunction · 0.90
from_configMethod · 0.80
toMethod · 0.80
from_dictMethod · 0.80

Tested by

no test coverage detected