(self, project_name, experiment_name, default_backend: Union[str, List[str]] = 'console', config=None)
# Backend names that ``__init__`` accepts without error. The legacy name
# ``'tracking'`` is not listed here; ``__init__`` checks for it explicitly
# (emitting a DeprecationWarning) before consulting this list.
supported_backend = ["wandb", "mlflow", "console"]
| 26 | |
def __init__(self, project_name, experiment_name, default_backend: Union[str, List[str]] = 'console', config=None):
    """Validate the requested backends and initialize each enabled one.

    After construction, ``self.logger`` maps each enabled backend name
    (``'wandb'``, ``'mlflow'``, ``'console'``) to the object used for logging.

    Args:
        project_name: Passed to ``wandb.init`` as ``project`` when wandb is enabled.
        experiment_name: Passed to ``wandb.init`` as ``name`` and to
            ``mlflow.start_run`` as ``run_name``.
        default_backend: A single backend name or a list of names. Each name must
            be in ``supported_backend``, or be ``'tracking'``, a deprecated alias
            that enables wandb and emits a ``DeprecationWarning``.
        config: Run configuration. Forwarded to ``wandb.init(config=...)``; for
            mlflow it is converted by ``_compute_mlflow_params_from_objects``
            before being passed to ``mlflow.log_params``.

    Raises:
        AssertionError: If a backend name is neither ``'tracking'`` nor listed in
            ``supported_backend``.
    """
    import warnings

    if isinstance(default_backend, str):
        default_backend = [default_backend]

    # Validate every name before any backend is initialized below.
    # An explicit raise is used instead of ``assert`` so the check still runs
    # under ``python -O``. The AssertionError type and message are unchanged,
    # so existing callers see the same exception.
    for backend in default_backend:
        if backend == 'tracking':
            # stacklevel=2 attributes the warning to the code that constructed
            # this object rather than to this line.
            warnings.warn("`tracking` logger is deprecated. use `wandb` instead.", DeprecationWarning, stacklevel=2)
        elif backend not in self.supported_backend:
            raise AssertionError(f'{backend} is not supported')

    self.logger = {}

    # 'tracking' is treated as an alias of 'wandb'; at most one wandb run is started.
    if 'tracking' in default_backend or 'wandb' in default_backend:
        import os

        import wandb

        # Log in explicitly only when a key is provided via the environment;
        # otherwise no explicit login is performed here.
        wandb_api_key = os.environ.get("WANDB_API_KEY")
        if wandb_api_key:
            wandb.login(key=wandb_api_key)
        wandb.init(project=project_name, name=experiment_name, config=config)
        # The wandb module itself is stored as the logger object.
        self.logger['wandb'] = wandb

    if 'mlflow' in default_backend:
        import mlflow

        mlflow.start_run(run_name=experiment_name)
        mlflow.log_params(_compute_mlflow_params_from_objects(config))
        self.logger['mlflow'] = _MlflowLoggingAdapter()

    if 'console' in default_backend:
        from verl.utils.logger.aggregate_logger import LocalLogger

        # ``console_logger`` is only set when the console backend is enabled.
        self.console_logger = LocalLogger(print_to_console=True)
        self.logger['console'] = self.console_logger
| 58 | |
| 59 | def log(self, data, step, backend=None): |
| 60 | for default_backend, logger_instance in self.logger.items(): |
nothing calls this directly
no test coverage detected