(model, args, iteration_time)
| 220 | |
| 221 | |
def flops_calculator(model, args, iteration_time):
    """Report effective TFLOPS per GPU for one training iteration (DISABLED).

    NOTE: this function is currently a no-op. It returns ``None``
    immediately, before computing or printing anything, because the
    original author marked it "currently broken". Everything after the
    early ``return`` is kept only as the reference implementation and is
    never executed.

    Intended behavior of the disabled code:
      1. Read the world size of the model-parallel group.
      2. Estimate GFLOPs per train step from the parameter count (in
         billions), the batch (micro-batch size x number of micro-batches)
         and ``args.seq_length``.
      3. Convert that to TFLOPS per GPU and pass a summary line to
         ``print_rank_0``.

    Args:
        model: model passed to ``get_parameters_in_billions``.
        args: config object; ``micro_batch_size`` and ``seq_length`` are
            read from it.
        iteration_time: duration of one train step. Presumably seconds,
            given the GFLOP -> TFLOP/s conversion below (TODO confirm).

    Returns:
        None.
    """
    return  # currently broken

    # ---- reference implementation below: unreachable while disabled ----
    mp_group = mpu.get_model_parallel_group()
    model_parallel_size = torch.distributed.get_world_size(group=mp_group)

    params_b = get_parameters_in_billions(model)
    step_batch_size = args.micro_batch_size * get_num_microbatches()

    # The 2.0 * 4.0 factor presumably is 2 FLOPs per multiply-accumulate
    # times 4 passes (forward, ~2x-cost backward, activation-recompute
    # forward) -- TODO confirm against the checkpointing setup.
    step_gflops = params_b * step_batch_size * args.seq_length * 2.0 * 4.0

    # GFLOP -> TFLOP (/1000), per unit of iteration_time, then divided
    # across the GPUs that jointly hold one model replica.
    per_gpu_tflops = step_gflops / (
        iteration_time * 1000.0 * model_parallel_size
    )

    print_rank_0(
        f"Effective Tera Flops per GPU: {round(per_gpu_tflops, 2)} and total parameters {round(params_b, 3)} B"
    )
no test coverage detected