github.com/Audio-AGI/AudioSep / training_step

Method training_step

models/audiosep.py:52–113


(self, batch_data_dict, batch_idx)
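The docstring documents the batch_data_dict layout only by example. The sketch below mocks a batch with that layout so shapes can be checked without running the model. It assumes NumPy arrays stand in for the torch tensors the method actually receives. The helper make_dummy_batch and the sizes used are hypothetical.

```python
import numpy as np


def make_dummy_batch(batch_size: int = 4, samples: int = 32000) -> dict:
    """Hypothetical helper: build a batch shaped like the docstring example."""
    rng = np.random.default_rng(0)
    return {
        "audio_text": {
            # one caption per clip
            "text": ["a sound of dog"] * batch_size,
            # (batch_size, 1, samples): mono waveforms with an explicit channel axis
            "waveform": rng.standard_normal((batch_size, 1, samples)).astype(np.float32),
        }
    }


batch = make_dummy_batch()
audio_text = batch["audio_text"]
# Captions and waveforms must line up one-to-one along the batch axis
assert len(audio_text["text"]) == audio_text["waveform"].shape[0]
print(audio_text["waveform"].shape)  # (4, 1, 32000)
```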


    def training_step(self, batch_data_dict, batch_idx):
        r"""Forward a mini-batch through the model, compute the loss, and
        train for one step. The mini-batch is split evenly across devices
        (if there are several) for parallel training.

        Args:
            batch_data_dict: e.g.
                'audio_text': {
                    'text': ['a sound of dog', ...],
                    'waveform': (batch_size, 1, samples)
                }
            batch_idx: int

        Returns:
            loss: torch.Tensor, scalar loss of this mini-batch
        """
        # [important] seed Python's `random` with batch_idx so that every
        # device makes the same random choices for this step
        random.seed(batch_idx)

        batch_audio_text_dict = batch_data_dict['audio_text']

        batch_text = batch_audio_text_dict['text']
        batch_audio = batch_audio_text_dict['waveform']
        device = batch_audio.device

        mixtures, segments = self.waveform_mixer(
            waveforms=batch_audio
        )

        # compute query embeddings from the captions and/or the target segments
        if self.query_encoder_type == 'CLAP':
            conditions = self.query_encoder.get_query_embed(
                modality='hybird',  # (sic) spelling as in the source
                text=batch_text,
                audio=segments.squeeze(1),
                use_text_ratio=self.use_text_ratio,
            )

        input_dict = {
            # [:, None, :] followed by .squeeze(1) leaves the shape unchanged
            'mixture': mixtures[:, None, :].squeeze(1),
            'condition': conditions,
        }

        target_dict = {
            'segment': segments.squeeze(1),
        }

        self.ss_model.train()
        sep_segment = self.ss_model(input_dict)['waveform']
        # (batch_size, 1, segment_samples)
        sep_segment = sep_segment.squeeze()
        # (batch_size, segment_samples); squeeze() with no argument also
        # drops the batch axis when batch_size == 1

        output_dict = {
            'segment': sep_segment,
        }

        # Calculate loss.
        loss = self.loss_function(output_dict, target_dict)
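Two shape manipulations in the listing are easy to misread: inserting an axis with [:, None, :] and immediately squeezing it, and squeezing the model output with or without an explicit axis. The sketch below reproduces both with NumPy as a stand-in for torch, assuming torch indexing and squeeze behave the same way for these shapes. The array sizes are illustrative only.

```python
import numpy as np

batch_size, samples = 1, 8
mixtures = np.zeros((batch_size, 1, samples))

# mixtures[:, None, :] inserts an axis at position 1; squeezing axis 1 removes it again
roundtrip = np.squeeze(mixtures[:, None, :], axis=1)
assert roundtrip.shape == mixtures.shape  # (1, 1, 8): a no-op

# Model output with a channel axis: (batch_size, 1, segment_samples)
sep_segment = np.zeros((batch_size, 1, samples))
print(np.squeeze(sep_segment).shape)          # (8,): batch axis lost when batch_size == 1
print(np.squeeze(sep_segment, axis=1).shape)  # (1, 8): keeps the batch axis
```

Squeezing an explicit axis keeps output and target shapes aligned regardless of batch size.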

Callers

nothing calls this directly (as a PyTorch Lightning hook, training_step is invoked by the Trainer for each batch)
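In a data-parallel run, each device calls training_step with the same batch_idx at a given step. Seeding Python's random module with it therefore makes any draws from that module identical across devices. The stdlib simulation below illustrates this, assuming two hypothetical devices and a hypothetical helper draw_choices that flips per-sample coins. It is not AudioSep code.

```python
import random
from typing import List


def draw_choices(batch_idx: int, batch_size: int, use_text_ratio: float) -> List[bool]:
    """Hypothetical per-sample coin flips: True means 'use the text query'."""
    random.seed(batch_idx)  # same seed on every device for this step
    return [random.random() < use_text_ratio for _ in range(batch_size)]


# Two simulated devices processing the same step (batch_idx = 7)
device_0 = draw_choices(batch_idx=7, batch_size=8, use_text_ratio=0.5)
device_1 = draw_choices(batch_idx=7, batch_size=8, use_text_ratio=0.5)
assert device_0 == device_1  # identical decisions on both devices
```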

Calls 1

get_query_embed · Method · 0.80
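training_step calls get_query_embed with modality='hybird', the captions, the target segments, and use_text_ratio. Its internals are not shown on this page. The sketch below is a hypothetical illustration, assuming the hybrid mode picks each sample's condition from either its text embedding or its audio embedding with probability use_text_ratio. The function hybrid_query_embed and the 512-dimensional stand-in embeddings are assumptions, not AudioSep's implementation.

```python
import random

import numpy as np


def hybrid_query_embed(text_embed: np.ndarray, audio_embed: np.ndarray,
                       use_text_ratio: float) -> np.ndarray:
    """Hypothetical hybrid query: per sample, take the text embedding with
    probability use_text_ratio, otherwise the audio embedding."""
    assert text_embed.shape == audio_embed.shape  # (batch_size, embed_dim)
    use_text = np.array([random.random() < use_text_ratio
                         for _ in range(text_embed.shape[0])])
    # Broadcast the (batch_size, 1) mask over the embedding dimension
    return np.where(use_text[:, None], text_embed, audio_embed)


random.seed(0)
text_embed = np.ones((4, 512))    # stand-in for text embeddings
audio_embed = np.zeros((4, 512))  # stand-in for audio embeddings

conditions = hybrid_query_embed(text_embed, audio_embed, use_text_ratio=1.0)
assert conditions.shape == (4, 512) and np.all(conditions == 1.0)  # always text

conditions = hybrid_query_embed(text_embed, audio_embed, use_text_ratio=0.0)
assert np.all(conditions == 0.0)  # always audio
```

Mixing the two query types during training is one way a single separator can accept either a text or an audio query at inference time.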

Tested by

no test coverage detected