MCPcopy
hub / github.com/THUDM/CogDL / raw_experiment

Function raw_experiment

cogdl/experiments.py:269–299  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

267
268
def raw_experiment(args):
    """Train every generated experiment variant and report the grouped results.

    Variants are produced by ``gen_variants`` from ``args.dataset``,
    ``args.model``, ``args.seed`` and ``args.split``.

    Variants run sequentially in this process (via ``train``) when only one
    device is configured, when ``args.cpu`` is set, or when
    ``args.distributed`` is set. Otherwise they are spread over a process pool
    with one worker per entry of ``args.devices`` (via ``train_parallel``).

    Args:
        args: Experiment namespace. Reads ``dataset``, ``model``, ``seed``,
            ``split``, ``devices``, ``cpu``, ``distributed`` and, if present,
            ``tablefmt`` (defaults to ``"github"``). In the multi-device
            branch, ``args.pid_to_cuda`` is assigned as a side effect.
            ``args.dataset`` is temporarily overwritten while datasets are
            pre-built, and is always restored afterwards.

    Returns:
        A ``defaultdict(list)`` mapping ``variant[:-2]`` (each variant with
        its last two fields dropped) to the list of results returned by
        ``train`` / ``train_parallel`` for those variants. The same mapping is
        also passed to ``output_results``.
    """
    variants = list(
        gen_variants(dataset=args.dataset, model=args.model, seed=args.seed, split=args.split)
    )

    if len(args.devices) == 1 or args.cpu or args.distributed:
        # Use a distinct name for each per-variant namespace rather than
        # shadowing the outer ``args`` inside the comprehension.
        results = [train(variant_args) for variant_args in variant_args_generator(args, variants)]
    else:
        # NOTE(review): "spawn" is forced presumably because workers use CUDA,
        # which cannot be re-initialised in forked children -- confirm.
        mp.set_start_method("spawn", force=True)

        # Build (and thereby download) every dataset once in the parent
        # process before any worker starts.
        datasets = args.dataset
        try:
            for dataset in datasets:
                args.dataset = dataset
                build_dataset(args)
        finally:
            # Restore the full dataset collection even if a build fails, so the
            # caller's namespace is never left pointing at a single dataset.
            args.dataset = datasets

        num_workers = len(args.devices)
        with mp.Pool(processes=num_workers) as pool:
            # Associate each worker pid with a device.
            # NOTE(review): this assumes ``getpid`` causes every worker to
            # handle exactly one task, so that all pids are distinct -- confirm.
            pids = pool.map(getpid, range(num_workers))
            args.pid_to_cuda = dict(zip(pids, args.devices))

            results = pool.map(train_parallel, variant_args_generator(args, variants))

    results_dict = defaultdict(list)
    for variant, result in zip(variants, results):
        # Dropping the last two fields pools runs that differ only in those
        # fields (presumably seed and split, given the gen_variants keyword
        # order -- confirm).
        results_dict[variant[:-2]].append(result)

    output_results(results_dict, getattr(args, "tablefmt", "github"))

    return results_dict
300
301
302def auto_experiment(args):

Callers 2

_objective (method) · 0.70
experiment (function) · 0.70

Calls 5

build_dataset (function) · 0.90
gen_variants (function) · 0.70
train (function) · 0.70
variant_args_generator (function) · 0.70
output_results (function) · 0.70

Tested by

no test coverage detected