Single-GPU single-cost single-tower trainer.
| 41 | |
| 42 | |
class SimpleTrainer(SingleCostTrainer):
    """
    Single-GPU single-cost single-tower trainer.

    Builds exactly one training tower (inside ``TrainTowerContext('')``) and
    applies that tower's gradients with the optimizer returned by
    ``get_opt_fn``.
    """

    def _setup_graph(self, input, get_cost_fn, get_opt_fn):
        """
        Build the only training tower and store the update op on ``self.train_op``.

        Args:
            input: forwarded unchanged to ``self._make_get_grad_fn``.
            get_cost_fn: forwarded unchanged to ``self._make_get_grad_fn``.
            get_opt_fn: zero-argument callable returning the optimizer whose
                ``apply_gradients`` creates the training op.

        Returns:
            list: always an empty list.
        """
        logger.info("Building graph for a single training tower ...")
        # Only one tower exists, so it is built under an empty tower name.
        with TrainTowerContext(''):
            compute_gradients = self._make_get_grad_fn(input, get_cost_fn, get_opt_fn)
            # NOTE(review): presumably a list of (gradient, variable) pairs,
            # as expected by optimizer.apply_gradients — confirm.
            grads_and_vars = compute_gradients()
            optimizer = get_opt_fn()
            self.train_op = optimizer.apply_gradients(grads_and_vars, name='train_op')
        return []
| 54 | |
| 55 | |
| 56 | class NoOpTrainer(SimpleTrainer): |
no outgoing calls
no test coverage detected