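Collater (mapping function) for the TensorFlow Audio-Mel Dataset: crops or pads each audio/mel pair to a fixed length. Args: items (dict): Dataset example containing the keys "utt_ids", "audios" and "mels". batch_max_steps (int): The maximum length of the input signal in a batch. hop_size (int): Hop size of the auxiliary features.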
(
items,
batch_max_steps=tf.constant(8192, dtype=tf.int32),
hop_size=tf.constant(256, dtype=tf.int32),
)
| 240 | |
| 241 | |
def collater(
    items,
    batch_max_steps=tf.constant(8192, dtype=tf.int32),
    hop_size=tf.constant(256, dtype=tf.int32),
):
    """Crop or pad one audio/mel example to a fixed length.

    Intended as the mapping function for the TensorFlow Audio-Mel Dataset.
    The audio is first zero-padded so that it covers at least
    ``len(mel) * hop_size`` samples. Then:

    * if the mel has more than ``batch_max_steps // hop_size`` frames, a
      random window of that many frames is cut out, together with the
      matching ``batch_max_steps`` audio samples;
    * otherwise both audio and mel are zero-padded at the end up to
      ``batch_max_steps`` samples / ``batch_max_steps // hop_size`` frames.

    Args:
        items (dict): Example with the keys "utt_ids", "audios" and "mels".
            "audios" is padded along a single axis and "mels" along the first
            of two axes, i.e. they are treated as rank-1 and rank-2 tensors.
        batch_max_steps (int): The maximum length of input signal in batch.
            If None, the audio length rounded down to a multiple of
            ``hop_size`` is used.
        hop_size (int): Hop size of auxiliary features.

    Returns:
        dict: "utt_ids" (passed through unchanged), "audios" and "mels"
        (cropped or padded), and "mel_lengths" / "audio_lengths" holding
        ``len()`` of the returned mel / audio.

    """
    audio, mel = items["audios"], items["mels"]

    if batch_max_steps is None:
        batch_max_steps = (tf.shape(audio)[0] // hop_size) * hop_size

    batch_max_frames = batch_max_steps // hop_size

    # Make sure every mel frame has hop_size audio samples behind it.
    if len(audio) < len(mel) * hop_size:
        audio = tf.pad(audio, [[0, len(mel) * hop_size - len(audio)]])

    if len(mel) > batch_max_frames:
        # Randomly pick a window of batch_max_frames frames.
        interval_start = 0
        interval_end = len(mel) - batch_max_frames
        # tf.random.uniform excludes maxval for integer dtypes, so add 1 to
        # keep the last valid start frame (interval_end) selectable.
        start_frame = tf.random.uniform(
            shape=[], minval=interval_start, maxval=interval_end + 1, dtype=tf.int32
        )
        start_step = start_frame * hop_size
        audio = audio[start_step : start_step + batch_max_steps]
        mel = mel[start_frame : start_frame + batch_max_frames, :]
    else:
        # The audio may be longer than batch_max_steps even though the mel
        # fits (trailing samples beyond len(mel) * hop_size, or a length that
        # is not a multiple of hop_size). Trim first so the padding amount
        # below can never be negative, which tf.pad rejects.
        audio = audio[:batch_max_steps]
        audio = tf.pad(audio, [[0, batch_max_steps - len(audio)]])
        mel = tf.pad(mel, [[0, batch_max_frames - len(mel)], [0, 0]])

    items = {
        "utt_ids": items["utt_ids"],
        "audios": audio,
        "mels": mel,
        "mel_lengths": len(mel),
        "audio_lengths": len(audio),
    }

    return items
| 286 | |
| 287 | |
| 288 | def main(): |