(self, batch, batch_idx)
| 137 | return optimizer |
| 138 | |
| 139 | def training_step(self, batch, batch_idx): |
| 140 | imgs, prompts = batch["image"], batch["description"] |
| 141 | image_latent_mask = batch.get("image_latent_mask", None) |
| 142 | |
| 143 | # Get the conditions and position deltas from the batch |
| 144 | conditions, position_deltas, position_scales, latent_masks = [], [], [], [] |
| 145 | for i in range(1000): |
| 146 | if f"condition_{i}" not in batch: |
| 147 | break |
| 148 | conditions.append(batch[f"condition_{i}"]) |
| 149 | position_deltas.append(batch.get(f"position_delta_{i}", [[0, 0]])) |
| 150 | position_scales.append(batch.get(f"position_scale_{i}", [1.0])[0]) |
| 151 | latent_masks.append(batch.get(f"condition_latent_mask_{i}", None)) |
| 152 | |
| 153 | # Prepare inputs |
| 154 | with torch.no_grad(): |
| 155 | # Prepare image input |
| 156 | x_0, img_ids = encode_images(self.flux_pipe, imgs) |
| 157 | |
| 158 | # Prepare text input |
| 159 | ( |
| 160 | prompt_embeds, |
| 161 | pooled_prompt_embeds, |
| 162 | text_ids, |
| 163 | ) = self.flux_pipe.encode_prompt( |
| 164 | prompt=prompts, |
| 165 | prompt_2=None, |
| 166 | prompt_embeds=None, |
| 167 | pooled_prompt_embeds=None, |
| 168 | device=self.flux_pipe.device, |
| 169 | num_images_per_prompt=1, |
| 170 | max_sequence_length=self.model_config.get("max_sequence_length", 512), |
| 171 | lora_scale=None, |
| 172 | ) |
| 173 | |
| 174 | # Prepare t and x_t |
| 175 | t = torch.sigmoid(torch.randn((imgs.shape[0],), device=self.device)) |
| 176 | x_1 = torch.randn_like(x_0).to(self.device) |
| 177 | t_ = t.unsqueeze(1).unsqueeze(1) |
| 178 | x_t = ((1 - t_) * x_0 + t_ * x_1).to(self.dtype) |
| 179 | if image_latent_mask is not None: |
| 180 | x_0 = x_0[:, image_latent_mask[0]] |
| 181 | x_1 = x_1[:, image_latent_mask[0]] |
| 182 | x_t = x_t[:, image_latent_mask[0]] |
| 183 | img_ids = img_ids[image_latent_mask[0]] |
| 184 | |
| 185 | # Prepare conditions |
| 186 | condition_latents, condition_ids = [], [] |
| 187 | for cond, p_delta, p_scale, latent_mask in zip( |
| 188 | conditions, position_deltas, position_scales, latent_masks |
| 189 | ): |
| 190 | # Prepare conditions |
| 191 | c_latents, c_ids = encode_images(self.flux_pipe, cond) |
| 192 | # Scale the position (see OminiConrtol2) |
| 193 | if p_scale != 1.0: |
| 194 | scale_bias = (p_scale - 1.0) / 2 |
| 195 | c_ids[:, 1:] *= p_scale |
| 196 | c_ids[:, 1:] += scale_bias |
nothing calls this directly
no test coverage detected