()
| 48 | |
| 49 | |
| 50 | def test_hf_casual_fwd(): |
| 51 | assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test" |
| 52 | |
| 53 | sp_size = 8 |
| 54 | dp_size = 1 |
| 55 | ulysses_device_mesh = init_device_mesh(device_type='cuda', |
| 56 | mesh_shape=(dp_size, sp_size), |
| 57 | mesh_dim_names=('dp', 'sp')) |
| 58 | sharding_manager = FSDPUlyssesShardingManager(ulysses_device_mesh) |
| 59 | |
| 60 | batch_size = 1 |
| 61 | seqlen = 128 |
| 62 | response_length = 127 |
| 63 | |
| 64 | for model_name, (config, attn) in test_configs.items(): |
| 65 | # patch before load |
| 66 | attn.forward = patches[model_name] |
| 67 | with torch.device('cuda'): |
| 68 | model = AutoModelForCausalLM.from_config(config=config, |
| 69 | torch_dtype=torch.bfloat16, |
| 70 | attn_implementation='flash_attention_2') |
| 71 | model = model.to(device='cuda') |
| 72 | sync_model_parameters_global(model) |
| 73 | |
| 74 | # different rank will generate different input_ids following fsdp |
| 75 | input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device='cuda') |
| 76 | attention_mask = create_random_mask(input_ids=input_ids, |
| 77 | max_ratio_of_left_padding=0, |
| 78 | max_ratio_of_valid_token=0.9, |
| 79 | min_ratio_of_valid_token=0.8) |
| 80 | position_ids = compute_position_id_with_mask( |
| 81 | attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here |
| 82 | |
| 83 | model_inputs = { |
| 84 | 'input_ids': input_ids.cuda(), |
| 85 | 'attention_mask': attention_mask.cuda(), |
| 86 | 'position_ids': position_ids.int().cuda() |
| 87 | } |
| 88 | |
| 89 | model_inputs = DataProto.from_dict(model_inputs) |
| 90 | |
| 91 | # 1. perform ulysses forward |
| 92 | with sharding_manager: |
| 93 | model_inputs = sharding_manager.preprocess_data(model_inputs) |
| 94 | input_ids = model_inputs.batch['input_ids'] |
| 95 | attention_mask = model_inputs.batch['attention_mask'] |
| 96 | position_ids = model_inputs.batch['position_ids'] |
| 97 | input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), |
| 98 | attention_mask) # input_ids_rmpad (total_nnz, ...) |
| 99 | input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) |
| 100 | # unpad the position_ids to align the rotary |
| 101 | position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), |
| 102 | indices).transpose(0, 1) |
| 103 | |
| 104 | # slice input tensor for ulysses |
| 105 | # input_ids are padded and sliced |
| 106 | # position_ids are only padded but not sliced |
| 107 | input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs( |
no test coverage detected