(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
no_visdom: bool)
| 13 | |
| 14 | |
| 15 | def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int, |
| 16 | backup_every: int, vis_every: int, force_restart: bool, visdom_server: str, |
| 17 | no_visdom: bool): |
| 18 | # Create a dataset and a dataloader |
| 19 | dataset = SpeakerVerificationDataset(clean_data_root) |
| 20 | loader = SpeakerVerificationDataLoader( |
| 21 | dataset, |
| 22 | speakers_per_batch, |
| 23 | utterances_per_speaker, |
| 24 | num_workers=8, |
| 25 | ) |
| 26 | |
| 27 | # Setup the device on which to run the forward pass and the loss. These can be different, |
| 28 | # because the forward pass is faster on the GPU whereas the loss is often (depending on your |
| 29 | # hyperparameters) faster on the CPU. |
| 30 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| 31 | # FIXME: currently, the gradient is None if loss_device is cuda |
| 32 | loss_device = torch.device("cpu") |
| 33 | |
| 34 | # Create the model and the optimizer |
| 35 | model = SpeakerEncoder(device, loss_device) |
| 36 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init) |
| 37 | init_step = 1 |
| 38 | |
| 39 | # Configure file path for the model |
| 40 | state_fpath = models_dir.joinpath(run_id + ".pt") |
| 41 | backup_dir = models_dir.joinpath(run_id + "_backups") |
| 42 | |
| 43 | # Load any existing model |
| 44 | if not force_restart: |
| 45 | if state_fpath.exists(): |
| 46 | print("Found existing model \"%s\", loading it and resuming training." % run_id) |
| 47 | checkpoint = torch.load(state_fpath) |
| 48 | init_step = checkpoint["step"] |
| 49 | model.load_state_dict(checkpoint["model_state"]) |
| 50 | optimizer.load_state_dict(checkpoint["optimizer_state"]) |
| 51 | optimizer.param_groups[0]["lr"] = learning_rate_init |
| 52 | else: |
| 53 | print("No model \"%s\" found, starting training from scratch." % run_id) |
| 54 | else: |
| 55 | print("Starting the training from scratch.") |
| 56 | model.train() |
| 57 | |
| 58 | # Initialize the visualization environment |
| 59 | vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom) |
| 60 | vis.log_dataset(dataset) |
| 61 | vis.log_params() |
| 62 | device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU") |
| 63 | vis.log_implementation({"Device": device_name}) |
| 64 | |
| 65 | # Training loop |
| 66 | profiler = Profiler(summarize_every=10, disabled=False) |
| 67 | for step, speaker_batch in enumerate(loader, init_step): |
| 68 | profiler.tick("Blocking, waiting for batch (threaded)") |
| 69 | |
| 70 | # Forward pass |
| 71 | inputs = torch.from_numpy(speaker_batch.data).to(device) |
| 72 | sync(device) |
no test coverage detected