| 333 | |
| 334 | |
def experiment(dataset, model=None, **kwargs):
    """Run an experiment over one or more datasets and models.

    Args:
        dataset: A dataset name (``str``), a ``Dataset`` instance, or a list of them.
        model: A model name (``str``), an ``nn.Module`` instance, or a list of them.
            Defaults to ``"autognn"``.
        **kwargs: If an ``args`` keyword is given, it is used as the configuration
            object and every other keyword is set on it as an attribute. Otherwise
            all keywords are forwarded to ``get_default_args`` to build the
            configuration.

    Returns:
        The result of ``auto_experiment(args)`` when the configuration has a
        ``search_space`` attribute, otherwise the result of ``raw_experiment(args)``.
    """
    if model is None:
        model = "autognn"
    # Normalise single values to one-element lists so the rest of the function
    # only has to handle lists.
    if isinstance(dataset, (str, Dataset)):
        dataset = [dataset]
    if isinstance(model, (str, nn.Module)):
        model = [model]

    if "args" not in kwargs:
        args = get_default_args(dataset=[str(x) for x in dataset], model=[str(x) for x in model], **kwargs)
    else:
        args = kwargs["args"]
        for key, value in kwargs.items():
            if key != "args":
                setattr(args, key, value)

    if isinstance(model[0], nn.Module):
        # Show model names (not module objects) in the printed configuration.
        # Plain string entries in a mixed list are kept as-is rather than
        # failing on ``str.model_name``.
        args.model = [x.model_name if isinstance(x, nn.Module) else x for x in model]
    print(args)
    # After printing, store the actual dataset/model lists (which may hold
    # objects rather than names) for the experiment runners.
    args.dataset = dataset
    args.model = model

    # ``max_epoch`` is a deprecated alias for ``epochs``. A caller-supplied
    # ``args`` object may not define it at all, so read it defensively instead
    # of raising AttributeError.
    if getattr(args, "max_epoch", None) is not None:
        warnings.warn("The max_epoch is deprecated and will be removed in the future, please use epochs instead!")
        args.epochs = args.max_epoch

    if len(model) == 1 and isinstance(model[0], str) and model[0] == "autognn":
        # Fill in AutoGNN search defaults only where the configuration does not
        # already provide them.
        if not hasattr(args, "search_space"):
            args.search_space = default_search_space
        if not hasattr(args, "seed"):
            args.seed = [1, 2]
        if not hasattr(args, "n_trials"):
            args.n_trials = 20

    # Any configuration carrying a search space is routed to the search-based
    # runner; everything else runs directly.
    if hasattr(args, "search_space"):
        return auto_experiment(args)

    return raw_experiment(args)