()
| 287 | |
| 288 | |
| 289 | def main(): |
| 290 | args = parse_args() |
| 291 | disable_existing_loggers() |
| 292 | colossalai.legacy.launch_from_torch() |
| 293 | logger = get_dist_logger() |
| 294 | is_main_process = dist.get_rank() == 0 |
| 295 | |
| 296 | if is_main_process: |
| 297 | datasets.utils.logging.set_verbosity_warning() |
| 298 | logging.set_verbosity_info() |
| 299 | else: |
| 300 | datasets.utils.logging.set_verbosity_error() |
| 301 | logging.set_verbosity_error() |
| 302 | |
| 303 | if args.mem_cap > 0: |
| 304 | colo_memory_cap(args.mem_cap) |
| 305 | |
| 306 | # If passed along, set the training seed now. |
| 307 | if args.seed is not None: |
| 308 | set_seed(args.seed) |
| 309 | logger.info(f"Rank {dist.get_rank()}: random seed is set to {args.seed}") |
| 310 | |
| 311 | # Handle the repository creation |
| 312 | with barrier_context(): |
| 313 | if args.output_dir is not None: |
| 314 | os.makedirs(args.output_dir, exist_ok=True) |
| 315 | |
| 316 | # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) |
| 317 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ |
| 318 | # (the dataset will be downloaded automatically from the datasets Hub). |
| 319 | # |
| 320 | # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called |
| 321 | # 'text' is found. You can easily tweak this behavior (see below). |
| 322 | # |
| 323 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently |
| 324 | # download the dataset. |
| 325 | logger.info("Start preparing dataset", ranks=[0]) |
| 326 | if not args.synthetic: |
| 327 | if args.dataset_name is not None: |
| 328 | # Downloading and loading a dataset from the hub. |
| 329 | raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) |
| 330 | if "validation" not in raw_datasets.keys(): |
| 331 | raw_datasets["validation"] = load_dataset( |
| 332 | args.dataset_name, |
| 333 | args.dataset_config_name, |
| 334 | split=f"train[:{args.validation_split_percentage}%]", |
| 335 | ) |
| 336 | raw_datasets["train"] = load_dataset( |
| 337 | args.dataset_name, |
| 338 | args.dataset_config_name, |
| 339 | split=f"train[{args.validation_split_percentage}%:]", |
| 340 | ) |
| 341 | else: |
| 342 | data_files = {} |
| 343 | dataset_args = {} |
| 344 | if args.train_file is not None: |
| 345 | data_files["train"] = args.train_file |
| 346 | if args.validation_file is not None: |
no test coverage detected
searching dependent graphs…