(self, inputs)
| 362 | |
@torch.autocast('cuda', dtype=AUTOCAST_DTYPE)
def forward(self, inputs):
    """Embed the raw inputs and build the tensors consumed by the following transformer stages.

    Args:
        inputs: A 6-tuple ``(hidden_states, img_ids, timesteps, pooled_embeds,
            t5_prompt_embeds, llama3_prompt_embeds)``. ``llama3_prompt_embeds`` is indexed
            with every entry of ``self.llama_layers``.

    Returns:
        ``make_contiguous(hidden_states, initial_encoder_hidden_states,
        llama_encoder_hidden_states, adaln_input, rope)``. The first four tensors are cast
        to ``AUTOCAST_DTYPE``; ``rope`` is the ``self.pe_embedder`` output for the
        concatenated image + text position ids.

    Raises:
        RuntimeError: if ``self.caption_projection`` is None.
    """
    hidden_states, img_ids, timesteps, pooled_embeds, t5_prompt_embeds, llama3_prompt_embeds = inputs

    if self.caption_projection is None:
        # Explicit check instead of `assert`: asserts are stripped under `python -O`.
        # Without the projections, the encoder list built below would be missing the T5
        # entry and every index into it would silently pick the wrong tensor.
        raise RuntimeError('caption_projection must be set to run forward()')

    batch_size = hidden_states.shape[0]
    # Captured before x_embedder runs; the timestep embedder is given the input dtype.
    hidden_states_type = hidden_states.dtype

    # Conditioning vector = timestep embedding + pooled-prompt embedding.
    timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
    timesteps = self.t_embedder(timesteps, hidden_states_type)
    adaln_input = timesteps + self.p_embedder(pooled_embeds)

    hidden_states = self.x_embedder(hidden_states)
    inner_dim = hidden_states.shape[-1]

    # Each selected Llama layer gets the projection at the same position as its entry in
    # self.llama_layers; the T5 embeddings use the last projection. All outputs are
    # reshaped to (batch_size, -1, inner_dim). The T5 entry is appended last, so
    # encoder_hidden_states = [llama_0, ..., llama_{n-1}, t5].
    encoder_hidden_states = [
        self.caption_projection[i](llama3_prompt_embeds[layer]).view(batch_size, -1, inner_dim)
        for i, layer in enumerate(self.llama_layers)
    ]
    t5_hidden_states = self.caption_projection[-1](t5_prompt_embeds).view(batch_size, -1, inner_dim)
    encoder_hidden_states.append(t5_hidden_states)

    # Text position ids are all zeros. Their length is T5 + last selected Llama layer + first
    # selected Llama layer.
    # NOTE(review): presumably this sizes rope for the longest text sequence a
    # downstream block attends over — confirm against the block implementation.
    t5_len = encoder_hidden_states[-1].shape[1]
    last_llama_len = encoder_hidden_states[-2].shape[1]
    first_llama_len = encoder_hidden_states[0].shape[1]
    txt_ids = torch.zeros(
        batch_size,
        t5_len + last_llama_len + first_llama_len,
        3,
        device=img_ids.device,
        dtype=img_ids.dtype,
    )
    ids = torch.cat((img_ids, txt_ids), dim=1)
    rope = self.pe_embedder(ids)

    # T5 states followed by the last selected Llama layer, joined along dim 1.
    initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
    # All projected Llama layers (everything except the trailing T5 entry) stacked on a new dim 0.
    llama_encoder_hidden_states = torch.stack(encoder_hidden_states[:-1], dim=0)

    # With nf4 quantization, tensors can end up float32, which breaks flash attention later, so we cast it here.
    hidden_states = hidden_states.to(AUTOCAST_DTYPE)
    initial_encoder_hidden_states = initial_encoder_hidden_states.to(AUTOCAST_DTYPE)
    llama_encoder_hidden_states = llama_encoder_hidden_states.to(AUTOCAST_DTYPE)
    adaln_input = adaln_input.to(AUTOCAST_DTYPE)

    return make_contiguous(hidden_states, initial_encoder_hidden_states, llama_encoder_hidden_states, adaln_input, rope)
| 412 | |
| 413 | |
| 414 | class TransformerWrapper(nn.Module): |
nothing calls this directly
no test coverage detected