()
| 186 | |
| 187 | |
| 188 | def main(): |
| 189 | args = get_arguments() |
| 190 | |
| 191 | try: |
| 192 | directories = validate_directories(args) |
| 193 | except ValueError as e: |
| 194 | print("Some arguments are wrong:") |
| 195 | print(str(e)) |
| 196 | return |
| 197 | |
| 198 | logdir = directories['logdir'] |
| 199 | restore_from = directories['restore_from'] |
| 200 | |
| 201 | # Even if we restored the model, we will treat it as new training |
| 202 | # if the trained model is written into an arbitrary location. |
| 203 | is_overwritten_training = logdir != restore_from |
| 204 | |
| 205 | with open(args.wavenet_params, 'r') as f: |
| 206 | wavenet_params = json.load(f) |
| 207 | |
| 208 | # Create coordinator. |
| 209 | coord = tf.train.Coordinator() |
| 210 | |
| 211 | # Load raw waveform from VCTK corpus. |
| 212 | with tf.name_scope('create_inputs'): |
| 213 | # Allow silence trimming to be skipped by specifying a threshold near |
| 214 | # zero. |
| 215 | silence_threshold = args.silence_threshold if args.silence_threshold > \ |
| 216 | EPSILON else None |
| 217 | gc_enabled = args.gc_channels is not None |
| 218 | reader = AudioReader( |
| 219 | args.data_dir, |
| 220 | coord, |
| 221 | sample_rate=wavenet_params['sample_rate'], |
| 222 | gc_enabled=gc_enabled, |
| 223 | receptive_field=WaveNetModel.calculate_receptive_field(wavenet_params["filter_width"], |
| 224 | wavenet_params["dilations"], |
| 225 | wavenet_params["scalar_input"], |
| 226 | wavenet_params["initial_filter_width"]), |
| 227 | sample_size=args.sample_size, |
| 228 | silence_threshold=silence_threshold) |
| 229 | audio_batch = reader.dequeue(args.batch_size) |
| 230 | if gc_enabled: |
| 231 | gc_id_batch = reader.dequeue_gc(args.batch_size) |
| 232 | else: |
| 233 | gc_id_batch = None |
| 234 | |
| 235 | # Create network. |
| 236 | net = WaveNetModel( |
| 237 | batch_size=args.batch_size, |
| 238 | dilations=wavenet_params["dilations"], |
| 239 | filter_width=wavenet_params["filter_width"], |
| 240 | residual_channels=wavenet_params["residual_channels"], |
| 241 | dilation_channels=wavenet_params["dilation_channels"], |
| 242 | skip_channels=wavenet_params["skip_channels"], |
| 243 | quantization_channels=wavenet_params["quantization_channels"], |
| 244 | use_biases=wavenet_params["use_biases"], |
| 245 | scalar_input=wavenet_params["scalar_input"], |
no test coverage detected