get_config(model_name, mode='train', dataset=None, **overwrite_kwargs)

Main entry point to get the config for the model.
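The precedence rules in the `get_config` docstring amount to merging dictionaries from lowest to highest precedence, so that later layers overwrite earlier ones. This is the same `{**config, **overwrite_kwargs}` pattern the function uses for its final override step. Below is a minimal sketch that assumes each layer has already been loaded into a plain dict. `merge_layers` and the sample keys (`lr`, `use_amp`) are hypothetical. They stand in for `update_model_config` and the JSON config files it loads.

```python
from typing import Any, Dict, Optional


def merge_layers(
    common: Dict[str, Any],
    version_cfg: Dict[str, Any],
    config_version_cfg: Optional[Dict[str, Any]],
    overrides: Dict[str, Any],
) -> Dict[str, Any]:
    """Hypothetical sketch of get_config's precedence: later layers win."""
    # 4. Lowest precedence: defaults shared by all models.
    # 3. Model-version-specific values overwrite the defaults.
    config = {**common, **version_cfg}
    # 2. An explicit config_version, if given, overwrites the version config.
    if config_version_cfg is not None:
        config = {**config, **config_version_cfg}
    # 1. Highest precedence: the caller's keyword arguments.
    return {**config, **overrides}


common = {"lr": 1e-4, "n_bins": 64, "use_amp": 0}   # illustrative defaults
version_cfg = {"n_bins": 128}                        # illustrative version layer
config_version_cfg = {"lr": 2e-4}                    # illustrative config_version layer
overrides = {"use_amp": True}                        # illustrative caller overrides

cfg = merge_layers(common, version_cfg, config_version_cfg, overrides)
print(cfg)  # {'lr': 0.0002, 'n_bins': 128, 'use_amp': True}
```

Because dict unpacking keeps the rightmost value for a repeated key, the order of the merges alone encodes the precedence. No per-key conflict handling is needed.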
def get_config(model_name, mode='train', dataset=None, **overwrite_kwargs):
    """Main entry point to get the config for the model.

    Args:
        model_name (str): Name of the desired model.
        mode (str, optional): "train", "infer", or "eval". Defaults to 'train'.
        dataset (str, optional): If specified, the corresponding dataset configuration is loaded as well. Defaults to None.

    Keyword Args:
        Key-value pairs of arguments that overwrite the default config.

    The order of precedence for overwriting the config is (highest precedence first):
        1. overwrite_kwargs
        2. "config_version": config file version, if specified in overwrite_kwargs. The corresponding config loaded is config_{model_name}_{config_version}.json
        3. "version_name": model-version-specific config, read from overwrite_kwargs if present and otherwise from the default config. The corresponding config loaded is config_{model_name}_{version_name}.json
        4. common_config: default config for all models, specified in COMMON_CONFIG (merged with COMMON_TRAINING_CONFIG)

    Returns:
        easydict: The config dictionary for the model.
    """

    check_choices("Model", model_name, ["zoedepth", "zoedepth_nk"])
    check_choices("Mode", mode, ["train", "infer", "eval"])
    if mode == "train":
        check_choices("Dataset", dataset, ["nyu", "kitti", "mix", None])

    config = flatten({**COMMON_CONFIG, **COMMON_TRAINING_CONFIG})
    config = update_model_config(config, mode, model_name)

    # update with model version specific config
    version_name = overwrite_kwargs.get("version_name", config["version_name"])
    config = update_model_config(config, mode, model_name, version_name)

    # update with config version if specified
    config_version = overwrite_kwargs.get("config_version", None)
    if config_version is not None:
        print("Overwriting config with config_version", config_version)
        config = update_model_config(config, mode, model_name, config_version)

    # update with overwrite_kwargs
    # Combined args are useful for hyperparameter search
    overwrite_kwargs = split_combined_args(overwrite_kwargs)
    config = {**config, **overwrite_kwargs}

    # Casting to bool # TODO: Not necessary. Remove and test
    for key in KEYS_TYPE_BOOL:
        if key in config:
            config[key] = bool(config[key])

    # Model specific post processing of config
    parse_list(config, "n_attractors")

    # adjust n_bins for each bin configuration if bin_conf is given and n_bins is passed in overwrite_kwargs
    if 'bin_conf' in config and 'n_bins' in overwrite_kwargs:
        bin_conf = config['bin_conf']  # list of dicts
        n_bins = overwrite_kwargs['n_bins']
        new_bin_conf = []
        for conf in bin_conf:
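The listing stops at the start of the `bin_conf` loop, so the loop body is not shown. Judging from the comment above it, the loop presumably writes the `n_bins` override into each entry of `bin_conf`. The sketch below illustrates that assumed behavior alongside the `KEYS_TYPE_BOOL` casting step. Several details are hypothetical: the helper names `cast_bool_keys` and `propagate_n_bins`, the sample key `use_pretrained`, and the `name` field in each bin entry. Copying each entry instead of mutating it is also a choice of this sketch, not something the source confirms.

```python
import copy
from typing import Any, Dict, Iterable


def cast_bool_keys(config: Dict[str, Any], bool_keys: Iterable[str]) -> None:
    """Coerce the listed keys to bool in place, like the KEYS_TYPE_BOOL loop."""
    for key in bool_keys:
        if key in config:
            config[key] = bool(config[key])


def propagate_n_bins(config: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical: copy an n_bins override into every entry of bin_conf."""
    if "bin_conf" in config and "n_bins" in overrides:
        n_bins = overrides["n_bins"]
        new_bin_conf = []
        for conf in config["bin_conf"]:
            conf = copy.copy(conf)  # assumption: avoid mutating the caller's dicts
            conf["n_bins"] = n_bins
            new_bin_conf.append(conf)
        config["bin_conf"] = new_bin_conf
    return config


overrides = {"n_bins": 32}
config = {
    "use_pretrained": 1,  # hypothetical bool-typed key
    "bin_conf": [{"name": "nyu", "n_bins": 64}, {"name": "kitti", "n_bins": 64}],
}
cast_bool_keys(config, ["use_pretrained"])
config = propagate_n_bins(config, overrides)
print(config["use_pretrained"], [c["n_bins"] for c in config["bin_conf"]])
# True [32, 32]
```

Copying each entry leaves the original list of dicts untouched. An in-place update would also work if no other code holds references to those dicts.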