Repository: github.com/shenweichen/DeepCTR-Torch

Method _get_optim

deepctr_torch/models/basemodel.py:453–467
(self, optimizer)


```python
        self.metrics = self._get_metrics(metrics)

    def _get_optim(self, optimizer):
        if isinstance(optimizer, str):
            if optimizer == "sgd":
                optim = torch.optim.SGD(self.parameters(), lr=0.01)
            elif optimizer == "adam":
                optim = torch.optim.Adam(self.parameters())  # 0.001
            elif optimizer == "adagrad":
                optim = torch.optim.Adagrad(self.parameters())  # 0.01
            elif optimizer == "rmsprop":
                optim = torch.optim.RMSprop(self.parameters())
            else:
                raise NotImplementedError
        else:
            optim = optimizer
        return optim

    def _get_loss_func(self, loss):
        if isinstance(loss, str):
```
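Only the SGD branch sets a learning rate explicitly; the other three branches rely on PyTorch's constructor defaults, which the trailing comments hint at. The following is a minimal sketch, assuming a recent PyTorch release (where the Adam, Adagrad and RMSprop defaults are 0.001, 0.01 and 0.01), that re-creates the same dispatch outside the model so the resulting rates can be inspected. The function name `get_optim` and the `_OPTIMIZERS` table are hypothetical and not part of DeepCTR-Torch; unlike the original, the sketch attaches a message to the NotImplementedError.

```python
import torch
from torch import nn

# Hypothetical lookup table mirroring the branches of BaseModel._get_optim.
# SGD gets an explicit lr=0.01; the others use PyTorch's constructor defaults.
_OPTIMIZERS = {
    "sgd": lambda params: torch.optim.SGD(params, lr=0.01),
    "adam": lambda params: torch.optim.Adam(params),        # default lr=0.001
    "adagrad": lambda params: torch.optim.Adagrad(params),  # default lr=0.01
    "rmsprop": lambda params: torch.optim.RMSprop(params),  # default lr=0.01
}


def get_optim(module: nn.Module, optimizer):
    """Build an optimizer from a name, or return an existing optimizer unchanged."""
    if isinstance(optimizer, str):
        try:
            factory = _OPTIMIZERS[optimizer]
        except KeyError:
            # Same exception type as the original, plus a message naming the input.
            raise NotImplementedError(f"unsupported optimizer name: {optimizer!r}") from None
        return factory(module.parameters())
    # Non-string input (normally a torch.optim.Optimizer) is passed through as-is.
    return optimizer


if __name__ == "__main__":
    model = nn.Linear(4, 1)
    for name in ("sgd", "adam", "adagrad", "rmsprop"):
        opt = get_optim(model, name)
        print(name, type(opt).__name__, opt.param_groups[0]["lr"])

    # A custom learning rate requires building the optimizer yourself.
    custom = torch.optim.Adam(model.parameters(), lr=1e-4)
    print("passthrough:", get_optim(model, custom) is custom)
```

Because a string name gives no way to set the learning rate, using a different rate means constructing the optimizer over the model's parameters yourself and passing that instance; `_get_optim` returns it unchanged.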

Callers (1)

compile (method) · 0.95

Calls

no outgoing calls recorded (the method calls only PyTorch APIs: torch.optim constructors and self.parameters())

Tested by

no test coverage detected