(src, target)
| 52 | # We need to use (1) batch and (2) data parallelism |
| 53 | # http://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html. |
def train(src, target):
    """Run one optimisation step on a single (src, target) pair.

    The source is encoded once, then the decoder is unrolled over the target
    one position at a time with teacher forcing: the first decoder input is
    the SOS token and every later input is the previous ground-truth target
    element. The per-position criterion values are summed, back-propagated,
    and applied with a single optimizer step.

    Relies on names defined elsewhere in this module: ``str2tensor``,
    ``encoder``, ``decoder``, ``criterion``, ``optimizer`` and ``SOS_token``.

    Args:
        src: Source sequence, passed to ``str2tensor``.
        target: Target sequence, passed to ``str2tensor`` with ``eos=True``.

    Returns:
        The summed loss divided by ``len(target_var)``, as a plain Python
        number.
    """
    src_var = str2tensor(src)
    target_var = str2tensor(target, eos=True)  # Add the EOS token

    encoder_hidden = encoder.init_hidden()
    encoder_outputs, encoder_hidden = encoder(src_var, encoder_hidden)

    # The decoder starts from the encoder's final hidden state.
    hidden = encoder_hidden
    loss = 0

    for c in range(len(target_var)):
        # First, we feed SOS. Others, we use teacher forcing.
        token = target_var[c - 1] if c else str2tensor(SOS_token)
        # The attention output is not used during training.
        output, hidden, _attention = decoder(token, hidden, encoder_outputs)
        loss += criterion(output, target_var[c])

    # Clear stale gradients on both networks before back-propagating.
    encoder.zero_grad()
    decoder.zero_grad()
    loss.backward()
    optimizer.step()

    # On PyTorch >= 0.4 the accumulated loss is a 0-dim tensor and indexing it
    # with ``.data[0]`` is an error; ``.item()`` is the supported way to read
    # the scalar. Older releases lack ``.item()``, so keep the legacy access
    # as a fallback.
    total_loss = loss.item() if hasattr(loss, "item") else loss.data[0]
    return total_loss / len(target_var)
| 77 | |
| 78 | |
| 79 | # Translate the given input |
no test coverage detected