(args, total)
| 69 | return (epoch, last_global_step, last_global_data_samples) |
| 70 | |
def get_effective_batch(args, total):
    """Reduce ``total`` by every per-step batching factor via floor division.

    When ``args.local_rank != -1``, ``total`` is first floor-divided by
    ``dist.get_world_size()``. It is then floor-divided, in order, by
    ``args.train_micro_batch_size_per_gpu``,
    ``args.gradient_accumulation_steps`` and ``args.refresh_bucket_size``.

    Args:
        args: Object exposing ``local_rank``,
            ``train_micro_batch_size_per_gpu``,
            ``gradient_accumulation_steps`` and ``refresh_bucket_size``.
        total: Quantity to scale down (presumably a sample count; confirm
            against callers).

    Returns:
        The result of the chained floor divisions.
    """
    result = total
    if args.local_rank != -1:
        # local_rank != -1 selects the branch that also splits across ranks.
        result = result // dist.get_world_size()
    for divisor in (
        args.train_micro_batch_size_per_gpu,
        args.gradient_accumulation_steps,
        args.refresh_bucket_size,
    ):
        # Plain `//` (not `//=`) so an array/tensor `total` is never mutated.
        result = result // divisor
    return result
| 76 | |
| 77 | |
| 78 | def get_dataloader(args, dataset: Dataset, eval_set=False): |