(src, target)
| 36 | # We need to use (1) batch and (2) data parallelism |
| 37 | # http://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html. |
| 38 | def train(src, target): |
| 39 | src_var = str2tensor(src) |
| 40 | target_var = str2tensor(target, eos=True) # Add the EOS token |
| 41 | |
| 42 | encoder_hidden = encoder.init_hidden() |
| 43 | encoder_outputs, encoder_hidden = encoder(src_var, encoder_hidden) |
| 44 | |
| 45 | hidden = encoder_hidden |
| 46 | loss = 0 |
| 47 | |
| 48 | for c in range(len(target_var)): |
| 49 | # First, we feed SOS |
| 50 | # Others, we use teacher forcing |
| 51 | token = target_var[c - 1] if c else str2tensor(SOS_token) |
| 52 | output, hidden = decoder(token, hidden) |
| 53 | loss += criterion(output, target_var[c]) |
| 54 | |
| 55 | encoder.zero_grad() |
| 56 | decoder.zero_grad() |
| 57 | loss.backward() |
| 58 | optimizer.step() |
| 59 | |
| 60 | return loss.data[0] / len(target_var) |
| 61 | |
| 62 | |
| 63 | # Translate the given input |
no test coverage detected