| 451 | self.metrics = self._get_metrics(metrics) |
| 452 | |
def _get_optim(self, optimizer):
    """Resolve ``optimizer`` into an optimizer object.

    Args:
        optimizer: Either one of the names ``"sgd"``, ``"adam"``,
            ``"adagrad"`` or ``"rmsprop"`` (matched exactly, case-sensitive),
            or any non-string object, which is returned unchanged.

    Returns:
        A new ``torch.optim`` optimizer built over ``self.parameters()`` when
        a supported name is given; otherwise ``optimizer`` itself.

    Raises:
        NotImplementedError: If ``optimizer`` is a string that is not one of
            the supported names.
    """
    if not isinstance(optimizer, str):
        # Non-string values are passed through untouched
        # (presumably an already-constructed optimizer -- not validated here).
        return optimizer

    # Name -> factory taking the parameter iterable. Only SGD gets an explicit
    # learning rate (0.01); the others rely on the library's default settings.
    factories = {
        "sgd": lambda params: torch.optim.SGD(params, lr=0.01),
        "adam": torch.optim.Adam,
        "adagrad": torch.optim.Adagrad,
        "rmsprop": torch.optim.RMSprop,
    }

    try:
        build = factories[optimizer]
    except KeyError:
        # Same exception type as before, but now it says what was rejected and
        # what would have been accepted.
        supported = ", ".join(sorted(factories))
        raise NotImplementedError(
            f"Unsupported optimizer {optimizer!r}; expected one of: {supported}"
        ) from None

    # self.parameters() is only evaluated once the name is known to be valid.
    return build(self.parameters())
| 468 | |
| 469 | def _get_loss_func(self, loss): |
| 470 | if isinstance(loss, str): |