(config, params)
| 121 | |
| 122 | |
def get_optimizer(config, params):
    """Build the optimizer selected by ``config.optimizer.optimizer_type``.

    Args:
        config: Configuration object. ``config.optimizer`` is read for
            ``optimizer_type`` and ``learning_rate`` and, for Adadelta,
            ``adadelta_decay_rate`` and ``adadelta_epsilon``.
        params: Object exposing ``get_parameter_optimizer_dict()``; the
            value it returns is passed to the optimizer as its parameters.

    Returns:
        A ``torch.optim.Adam``, ``torch.optim.Adadelta`` or ``BertAdam``
        instance, depending on the configured optimizer type.

    Raises:
        TypeError: If the configured optimizer type is not ADAM, ADADELTA
            or BERT_ADAM.
    """
    param_groups = params.get_parameter_optimizer_dict()
    optimizer_config = config.optimizer
    optimizer_type = optimizer_config.optimizer_type

    if optimizer_type == OptimizerType.ADAM:
        return torch.optim.Adam(lr=optimizer_config.learning_rate,
                                params=param_groups)
    if optimizer_type == OptimizerType.ADADELTA:
        return torch.optim.Adadelta(
            lr=optimizer_config.learning_rate,
            rho=optimizer_config.adadelta_decay_rate,
            eps=optimizer_config.adadelta_epsilon,
            params=param_groups)
    if optimizer_type == OptimizerType.BERT_ADAM:
        # NOTE(review): max_grad_norm=-1 presumably disables gradient
        # clipping inside BertAdam -- confirm against its implementation.
        return BertAdam(param_groups,
                        lr=optimizer_config.learning_rate,
                        weight_decay=0, max_grad_norm=-1)

    # Report the same attribute the dispatch above inspected. The previous
    # code read ``config.optimizer_type`` here, which is not the path used
    # anywhere else in this function and could fail before the TypeError
    # was ever raised.
    raise TypeError(
        "Unsupported tensor optimizer type: %s.Supported optimizer "
        "type is: %s" % (optimizer_type, OptimizerType.str()))
| 142 | |
| 143 | |
| 144 | def get_hierar_relations(hierar_taxonomy, label_map): |
no test coverage detected