()
| 43 | |
| 44 | |
def get_config():
    """Assemble and return the training configuration.

    Steps, in order:
      * call ``logger.auto_set_dir()`` before anything else is built;
      * load the ``'train'`` and ``'test'`` datasets via ``get_data``;
      * feed the training data through ``QueueInput``;
      * register a ``ModelSaver`` callback and an ``InferenceRunner`` over
        the test data that reports ``ScalarStats('total_costs')``;
      * set one epoch to ``len(<train data>)`` steps and train for 100 epochs.

    Returns:
        The ``TrainConfig`` built from the pieces above.
    """
    logger.auto_set_dir()

    train_df = get_data('train')
    test_df = get_data('test')

    # Objects are created in the same order the original keyword
    # arguments were evaluated: model, input, callbacks, step count.
    net = Model()
    train_input = QueueInput(train_df)
    cbs = [
        ModelSaver(),
        InferenceRunner(test_df, [ScalarStats('total_costs')]),
    ]
    steps = len(train_df)

    return TrainConfig(
        model=net,
        data=train_input,
        callbacks=cbs,
        steps_per_epoch=steps,
        max_epoch=100,
    )
| 61 | |
| 62 | |
| 63 | if __name__ == '__main__': |
no test coverage detected