(args)
| 267 | |
| 268 | |
def raw_experiment(args):
    """Train every experiment variant described by ``args`` and tabulate the results.

    Variants are generated from ``args.dataset``, ``args.model``, ``args.seed``
    and ``args.split``. They are trained one after another in this process when
    exactly one device is listed or when ``args.cpu`` / ``args.distributed`` is
    truthy; otherwise they are dispatched to an ``mp.Pool`` with one worker per
    entry of ``args.devices``.

    Args:
        args: Experiment namespace. Reads ``dataset``, ``model``, ``seed``,
            ``split``, ``devices``, ``cpu``, ``distributed`` and, if present,
            ``tablefmt``. On the pool path ``args.pid_to_cuda`` is set as a
            side effect.

    Returns:
        A ``defaultdict(list)`` mapping ``variant[:-2]`` to the list of results
        of all variants sharing that prefix (presumably the seed and split are
        the two fields dropped -- confirm against ``gen_variants``).
    """
    variants = list(gen_variants(dataset=args.dataset, model=args.model, seed=args.seed, split=args.split))

    if len(args.devices) == 1 or args.cpu or args.distributed:
        # Distinct loop name so the ``args`` parameter is not shadowed.
        results = [train(variant_args) for variant_args in variant_args_generator(args, variants)]
    else:
        mp.set_start_method("spawn", force=True)

        # Make sure datasets are downloaded first. ``args.dataset`` is
        # temporarily replaced by each single dataset; the ``finally`` restores
        # the original value even when ``build_dataset`` raises.
        datasets = args.dataset
        try:
            for dataset in datasets:
                args.dataset = dataset
                build_dataset(args)
        finally:
            args.dataset = datasets

        num_workers = len(args.devices)
        with mp.Pool(processes=num_workers) as pool:
            # Pair each reported worker pid with one device.
            # NOTE(review): this relies on the ``num_workers`` probe tasks
            # landing on distinct workers -- confirm ``getpid`` ensures that.
            pids = pool.map(getpid, range(num_workers))
            args.pid_to_cuda = dict(zip(pids, args.devices))

            results = pool.map(train_parallel, variant_args_generator(args, variants))

    # Both execution paths yield one result per variant, in variant order, so
    # the grouping is done once here.
    results_dict = defaultdict(list)
    for variant, result in zip(variants, results):
        results_dict[variant[:-2]].append(result)

    # Fall back to the "github" table format when the namespace has no ``tablefmt``.
    output_results(results_dict, getattr(args, "tablefmt", "github"))

    return results_dict
| 300 | |
| 301 | |
| 302 | def auto_experiment(args): |
no test coverage detected