MCPcopy
hub / github.com/Yuanshi9815/OminiControl / training_step

Method training_step

omini/train_flux/trainer.py:139–252  ·  view source on GitHub ↗
(self, batch, batch_idx)

Source from the content-addressed store, hash-verified

137 return optimizer
138
139 def training_step(self, batch, batch_idx):
140 imgs, prompts = batch["image"], batch["description"]
141 image_latent_mask = batch.get("image_latent_mask", None)
142
143 # Get the conditions and position deltas from the batch
144 conditions, position_deltas, position_scales, latent_masks = [], [], [], []
145 for i in range(1000):
146 if f"condition_{i}" not in batch:
147 break
148 conditions.append(batch[f"condition_{i}"])
149 position_deltas.append(batch.get(f"position_delta_{i}", [[0, 0]]))
150 position_scales.append(batch.get(f"position_scale_{i}", [1.0])[0])
151 latent_masks.append(batch.get(f"condition_latent_mask_{i}", None))
152
153 # Prepare inputs
154 with torch.no_grad():
155 # Prepare image input
156 x_0, img_ids = encode_images(self.flux_pipe, imgs)
157
158 # Prepare text input
159 (
160 prompt_embeds,
161 pooled_prompt_embeds,
162 text_ids,
163 ) = self.flux_pipe.encode_prompt(
164 prompt=prompts,
165 prompt_2=None,
166 prompt_embeds=None,
167 pooled_prompt_embeds=None,
168 device=self.flux_pipe.device,
169 num_images_per_prompt=1,
170 max_sequence_length=self.model_config.get("max_sequence_length", 512),
171 lora_scale=None,
172 )
173
174 # Prepare t and x_t
175 t = torch.sigmoid(torch.randn((imgs.shape[0],), device=self.device))
176 x_1 = torch.randn_like(x_0).to(self.device)
177 t_ = t.unsqueeze(1).unsqueeze(1)
178 x_t = ((1 - t_) * x_0 + t_ * x_1).to(self.dtype)
179 if image_latent_mask is not None:
180 x_0 = x_0[:, image_latent_mask[0]]
181 x_1 = x_1[:, image_latent_mask[0]]
182 x_t = x_t[:, image_latent_mask[0]]
183 img_ids = img_ids[image_latent_mask[0]]
184
185 # Prepare conditions
186 condition_latents, condition_ids = [], []
187 for cond, p_delta, p_scale, latent_mask in zip(
188 conditions, position_deltas, position_scales, latent_masks
189 ):
190 # Prepare conditions
191 c_latents, c_ids = encode_images(self.flux_pipe, cond)
192 # Scale the position (see OminiControl2)
193 if p_scale != 1.0:
194 scale_bias = (p_scale - 1.0) / 2
195 c_ids[:, 1:] *= p_scale
196 c_ids[:, 1:] += scale_bias

Callers

nothing calls this directly

Calls 2

encode_images · Function · 0.85
transformer_forward · Function · 0.85

Tested by

no test coverage detected