Run training process.
()
| 242 | |
| 243 | |
| 244 | def main(): |
| 245 | """Run training process.""" |
| 246 | parser = argparse.ArgumentParser( |
| 247 | description="Train ParallelWaveGan (See detail in tensorflow_tts/examples/parallel_wavegan/train_parallel_wavegan.py)" |
| 248 | ) |
| 249 | parser.add_argument( |
| 250 | "--train-dir", |
| 251 | default=None, |
| 252 | type=str, |
| 253 | help="directory including training data. ", |
| 254 | ) |
| 255 | parser.add_argument( |
| 256 | "--dev-dir", |
| 257 | default=None, |
| 258 | type=str, |
| 259 | help="directory including development data. ", |
| 260 | ) |
| 261 | parser.add_argument( |
| 262 | "--use-norm", default=1, type=int, help="use norm mels for training or raw." |
| 263 | ) |
| 264 | parser.add_argument( |
| 265 | "--outdir", type=str, required=True, help="directory to save checkpoints." |
| 266 | ) |
| 267 | parser.add_argument( |
| 268 | "--config", type=str, required=True, help="yaml format configuration file." |
| 269 | ) |
| 270 | parser.add_argument( |
| 271 | "--resume", |
| 272 | default="", |
| 273 | type=str, |
| 274 | nargs="?", |
| 275 | help='checkpoint file path to resume training. (default="")', |
| 276 | ) |
| 277 | parser.add_argument( |
| 278 | "--verbose", |
| 279 | type=int, |
| 280 | default=1, |
| 281 | help="logging level. higher is more logging. (default=1)", |
| 282 | ) |
| 283 | parser.add_argument( |
| 284 | "--generator_mixed_precision", |
| 285 | default=0, |
| 286 | type=int, |
| 287 | help="using mixed precision for generator or not.", |
| 288 | ) |
| 289 | parser.add_argument( |
| 290 | "--discriminator_mixed_precision", |
| 291 | default=0, |
| 292 | type=int, |
| 293 | help="using mixed precision for discriminator or not.", |
| 294 | ) |
| 295 | args = parser.parse_args() |
| 296 | |
| 297 | # return strategy |
| 298 | STRATEGY = return_strategy() |
| 299 | |
| 300 | # set mixed precision config |
| 301 | if args.generator_mixed_precision == 1 or args.discriminator_mixed_precision == 1: |
no test coverage detected