Repository: github.com/THUDM/CogDL

Function experiment

cogdl/experiments.py:335–370
(dataset, model=None, **kwargs)

```python
# Imports assumed from the module header, which this excerpt omits.
import warnings

import torch.nn as nn

from cogdl.data import Dataset
from cogdl.options import get_default_args

# default_search_space, auto_experiment and raw_experiment are defined
# elsewhere in cogdl/experiments.py.


def experiment(dataset, model=None, **kwargs):
    # Without an explicit model, fall back to automated GNN search.
    if model is None:
        model = "autognn"
    # Wrap a single dataset or model so both are always lists.
    if isinstance(dataset, (str, Dataset)):
        dataset = [dataset]
    if isinstance(model, (str, nn.Module)):
        model = [model]
    if "args" not in kwargs:
        # Build default args from the names, forwarding the other kwargs.
        args = get_default_args(dataset=[str(x) for x in dataset], model=[str(x) for x in model], **kwargs)
    else:
        # Reuse the caller's args and apply every other kwarg on top.
        args = kwargs["args"]
        for key, value in kwargs.items():
            if key != "args":
                setattr(args, key, value)
    # For nn.Module inputs, show model names in the printed args.
    if isinstance(model[0], nn.Module):
        args.model = [x.model_name for x in model]
    print(args)
    # Hand the normalized lists (names or objects) to the runners.
    args.dataset = dataset
    args.model = model

    if args.max_epoch is not None:
        warnings.warn("The max_epoch is deprecated and will be removed in the future, please use epochs instead!")
        args.epochs = args.max_epoch

    # A lone "autognn" model gets default search settings unless already set.
    if len(model) == 1 and isinstance(model[0], str) and model[0] == "autognn":
        if not hasattr(args, "search_space"):
            args.search_space = default_search_space
        if not hasattr(args, "seed"):
            args.seed = [1, 2]
        if not hasattr(args, "n_trials"):
            args.n_trials = 20

    # Any search space, default or user-supplied, selects the search runner.
    if hasattr(args, "search_space"):
        return auto_experiment(args)

    return raw_experiment(args)
```
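The argument handling in `experiment` can be exercised without installing CogDL. The sketch below is a simplified, dependency-free re-implementation, not CogDL's code: `normalize_inputs`, `reuse_args` and `migrate_max_epoch` are hypothetical helpers, `argparse.Namespace` stands in for the object that `get_default_args` returns (an assumption), the `Dataset`/`nn.Module` checks are reduced to a plain `str` check, and the attribute values are illustrative.

```python
import argparse
import warnings


def normalize_inputs(dataset, model=None):
    """Hypothetical helper mirroring experiment()'s defaulting and wrapping."""
    if model is None:
        model = "autognn"  # same fallback as experiment()
    # experiment() also wraps Dataset and nn.Module instances; strings only here.
    if isinstance(dataset, str):
        dataset = [dataset]
    if isinstance(model, str):
        model = [model]
    return dataset, model


def reuse_args(**kwargs):
    """Hypothetical helper mirroring the else-branch: reuse kwargs["args"]."""
    args = kwargs["args"]
    for key, value in kwargs.items():
        if key != "args":
            setattr(args, key, value)  # mutates the caller's namespace in place
    return args


def migrate_max_epoch(args):
    """Hypothetical helper mirroring the max_epoch -> epochs deprecation shim."""
    if args.max_epoch is not None:
        warnings.warn("The max_epoch is deprecated and will be removed in the future, please use epochs instead!")
        args.epochs = args.max_epoch
    return args


print(normalize_inputs("cora"))  # (['cora'], ['autognn'])

ns = argparse.Namespace(epochs=500, hidden_size=64, max_epoch=None)
reuse_args(args=ns, hidden_size=32)
print(ns.epochs, ns.hidden_size)  # 500 32
```

Because the override loop calls `setattr` on the very object passed as `args`, a namespace reused across several `experiment` calls keeps the overrides (and the `dataset`/`model` assignments) from earlier calls; passing a copy, for example via `copy.copy`, avoids that.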

Callers 15 (9 listed)

train.py · File · 0.90
test_experiment · Function · 0.90
test_auto_experiment · Function · 0.90
test_autognn_experiment · Function · 0.90
2training.py · File · 0.90
3custom_dataset.py · File · 0.90
4custom_gnn.py · File · 0.90
2training_cn.py · File · 0.90
custom_dataset.py · File · 0.90

Calls 3

get_default_args · Function · 0.90
auto_experiment · Function · 0.85
raw_experiment · Function · 0.70
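Of these calls, `auto_experiment` and `raw_experiment` are mutually exclusive: `experiment` routes to `auto_experiment` whenever `args` ends up with a `search_space` attribute, and a lone `"autognn"` model fills one in, along with `seed=[1, 2]` and `n_trials=20`, if they are absent. The minimal sketch below reproduces only that routing decision. `choose_runner` and `placeholder_search_space` are hypothetical names, `argparse.Namespace` stands in for CogDL's args object, and the function returns the runner's name instead of calling CogDL.

```python
import argparse


def placeholder_search_space(trial):
    """Hypothetical stand-in for CogDL's default_search_space."""
    return {}


def choose_runner(args, model):
    """Return the name of the runner experiment() would call."""
    if len(model) == 1 and isinstance(model[0], str) and model[0] == "autognn":
        # Fill AutoGNN defaults only where the caller left them unset.
        if not hasattr(args, "search_space"):
            args.search_space = placeholder_search_space
        if not hasattr(args, "seed"):
            args.seed = [1, 2]
        if not hasattr(args, "n_trials"):
            args.n_trials = 20
    # Any search space, default or user-supplied, triggers the search path.
    if hasattr(args, "search_space"):
        return "auto_experiment"
    return "raw_experiment"


auto_args = argparse.Namespace()
runner = choose_runner(auto_args, ["autognn"])
print(runner, auto_args.seed, auto_args.n_trials)  # auto_experiment [1, 2] 20
print(choose_runner(argparse.Namespace(seed=[1]), ["gcn"]))  # raw_experiment
```

Since the `hasattr` checks only fill attributes the namespace lacks, a `seed` already present on `args` takes precedence over `[1, 2]`, and a named model such as `"gcn"` still takes the search path if the caller supplies its own `search_space`.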

Tested by 5 (3 listed)

test_experiment · Function · 0.72
test_auto_experiment · Function · 0.72
test_autognn_experiment · Function · 0.72