github.com/babysor/MockingBird

Function train

encoder/train.py:15–123
(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
          backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
          no_visdom: bool)

```python
def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
          backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
          no_visdom: bool):
    # Create a dataset and a dataloader.
    # speakers_per_batch, utterances_per_speaker and learning_rate_init are
    # module-level names; they are not defined in this excerpt.
    dataset = SpeakerVerificationDataset(clean_data_root)
    loader = SpeakerVerificationDataLoader(
        dataset,
        speakers_per_batch,
        utterances_per_speaker,
        num_workers=8,
    )

    # Set up the devices on which to run the forward pass and the loss. These can differ,
    # because the forward pass is faster on the GPU, whereas the loss is often (depending on
    # your hyperparameters) faster on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # FIXME: currently, the gradient is None if loss_device is cuda
    loss_device = torch.device("cpu")

    # Create the model and the optimizer
    model = SpeakerEncoder(device, loss_device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
    init_step = 1

    # Configure the file paths for the model
    state_fpath = models_dir.joinpath(run_id + ".pt")
    backup_dir = models_dir.joinpath(run_id + "_backups")

    # Load any existing model
    if not force_restart:
        if state_fpath.exists():
            print("Found existing model \"%s\", loading it and resuming training." % run_id)
            checkpoint = torch.load(state_fpath)
            init_step = checkpoint["step"]
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            optimizer.param_groups[0]["lr"] = learning_rate_init
        else:
            print("No model \"%s\" found, starting training from scratch." % run_id)
    else:
        print("Starting the training from scratch.")
    model.train()

    # Initialize the visualization environment
    vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
    vis.log_dataset(dataset)
    vis.log_params()
    device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
    vis.log_implementation({"Device": device_name})

    # Training loop
    profiler = Profiler(summarize_every=10, disabled=False)
    for step, speaker_batch in enumerate(loader, init_step):
        profiler.tick("Blocking, waiting for batch (threaded)")

        # Forward pass
        inputs = torch.from_numpy(speaker_batch.data).to(device)
        sync(device)
        # ... (excerpt ends here; train() continues through line 123 of encoder/train.py)
```
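
The device-setup comment says the forward pass runs on the GPU while the loss may run on the CPU. The sketch below shows that split under stated assumptions: a `torch.nn.Linear` layer stands in for `SpeakerEncoder`, the batch is assumed to hold `speakers_per_batch * utterances_per_speaker` rows stored speaker by speaker, the sizes are made up, and the loss is a toy centroid loss rather than the model's actual `loss` method. `sync` is called but not defined in this excerpt; the version here is an assumed helper that waits for queued CUDA work so that profiler timings measure real work.

```python
import torch


def sync(device: torch.device) -> None:
    # Assumed helper: CUDA kernels run asynchronously, so block until queued
    # work on this device finishes before taking a timing measurement.
    if device.type == "cuda":
        torch.cuda.synchronize(device)


# Hypothetical sizes, for illustration only.
speakers_per_batch, utterances_per_speaker, n_features, embed_dim = 4, 5, 40, 8

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loss_device = torch.device("cpu")

# Stand-in for SpeakerEncoder: a single linear layer on the forward-pass device.
encoder = torch.nn.Linear(n_features, embed_dim).to(device)
inputs = torch.randn(speakers_per_batch * utterances_per_speaker, n_features).to(device)

embeds = encoder(inputs)  # forward pass on `device`
sync(device)

# Group rows per speaker (assumes speaker-major ordering) and move them to the
# loss device. .to() is differentiable, so gradients flow back to `device`.
embeds_loss = embeds.view(speakers_per_batch, utterances_per_speaker, -1).to(loss_device)
centroids = embeds_loss.mean(dim=1, keepdim=True)
loss = ((embeds_loss - centroids) ** 2).mean()  # toy loss, not the model's loss
loss.backward()
sync(device)
```

On a CPU-only machine both devices are the CPU and the transfer is a no-op; the code path is the same either way.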

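When `force_restart` is false and `models_dir/<run_id>.pt` exists, `train()` restores the step counter, model weights and optimizer state, then resets the learning rate to `learning_rate_init`. A minimal round-trip sketch of that logic, assuming a checkpoint dict with the same three keys the loader reads (`step`, `model_state`, `optimizer_state`); `save_checkpoint` and `resume` are hypothetical helpers, and a small `torch.nn.Linear` with Adam stands in for the real model and optimizer.

```python
import tempfile
from pathlib import Path

import torch


def save_checkpoint(path: Path, step: int, model: torch.nn.Module,
                    optimizer: torch.optim.Optimizer) -> None:
    # Hypothetical helper: writes the three keys that train() reads back.
    torch.save({
        "step": step,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
    }, path)


def resume(path: Path, model: torch.nn.Module, optimizer: torch.optim.Optimizer,
           lr_init: float, force_restart: bool = False) -> int:
    # Hypothetical helper mirroring the "Load any existing model" branch.
    init_step = 1
    if not force_restart and path.exists():
        checkpoint = torch.load(path, map_location="cpu")
        init_step = checkpoint["step"]
        model.load_state_dict(checkpoint["model_state"])
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        # As in train(): discard the stored learning rate, use the initial one.
        optimizer.param_groups[0]["lr"] = lr_init
    return init_step


lr_init = 1e-4
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "demo_run.pt"  # hypothetical run_id "demo_run"

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr_init)
    model(torch.randn(3, 4)).sum().backward()
    optimizer.step()                        # populate Adam's per-parameter state
    optimizer.param_groups[0]["lr"] = 5e-5  # pretend the LR had been lowered
    save_checkpoint(path, 1000, model, optimizer)

    fresh = torch.nn.Linear(4, 2)
    fresh_opt = torch.optim.Adam(fresh.parameters(), lr=lr_init)
    init_step = resume(path, fresh, fresh_opt, lr_init)
    restart_step = resume(path, fresh, fresh_opt, lr_init, force_restart=True)

# train() numbers its steps with enumerate(loader, init_step).
first_steps = [step for step, _ in enumerate(["batch_a", "batch_b"], init_step)]
```

Because the loop is `enumerate(loader, init_step)`, step numbering continues from the stored `step` value instead of restarting at 1, while the learning rate always restarts at its initial value.
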
Callers 3

encoder_train.py · File · 0.90
vocoder_train.py · File · 0.90

Calls 15

log_dataset · Method · 0.95
log_params · Method · 0.95
log_implementation · Method · 0.95
tick · Method · 0.95
loss · Method · 0.95
do_gradient_ops · Method · 0.95
update · Method · 0.95
draw_projections · Method · 0.95
save · Method · 0.95
SpeakerEncoder · Class · 0.90

Tested by

no test coverage detected