(self, label_map, logger, evaluator, conf, loss_fn)
| 86 | |
| 87 | class ClassificationTrainer(object): |
def __init__(self, label_map, logger, evaluator, conf, loss_fn):
    """Store the trainer's collaborators and, if configured, hierarchy relations.

    Args:
        label_map: Label mapping, stored as ``self.label_map`` and passed to
            ``get_hierar_relations``. Presumably label -> index; TODO confirm
            against the caller.
        logger: Stored as ``self.logger``.
        evaluator: Stored as ``self.evaluator``.
        conf: Configuration object, stored as ``self.conf``. Reads
            ``conf.task_info.hierarchical``, and ``conf.task_info.hierar_taxonomy``
            when hierarchical is truthy.
        loss_fn: Stored as ``self.loss_fn``.

    After construction, ``self.hierar_relations`` holds the result of
    ``get_hierar_relations(conf.task_info.hierar_taxonomy, label_map)`` when
    ``conf.task_info.hierarchical`` is truthy, and ``None`` otherwise.
    """
    self.label_map = label_map
    self.logger = logger
    self.evaluator = evaluator
    self.conf = conf
    self.loss_fn = loss_fn
    # Always define the attribute. Non-hierarchical trainers then expose it as
    # None instead of raising AttributeError on access.
    self.hierar_relations = None
    if self.conf.task_info.hierarchical:
        self.hierar_relations = get_hierar_relations(
            self.conf.task_info.hierar_taxonomy, label_map)
| 97 | |
| 98 | def train(self, data_loader, model, optimizer, stage, epoch): |
| 99 | model.update_lr(optimizer, epoch) |
nothing calls this directly
no test coverage detected