MCP · Copy · Index your code
hub / github.com/tdrussell/diffusion-pipe / forward

Method forward

models/hidream.py:364–411  ·  view source on GitHub ↗
(self, inputs)

Source from the content-addressed store, hash-verified

362
@torch.autocast('cuda', dtype=AUTOCAST_DTYPE)
def forward(self, inputs):
    """Embed the raw HiDream inputs into the tensors used by the transformer blocks.

    Args:
        inputs: A 6-tuple ``(hidden_states, img_ids, timesteps, pooled_embeds,
            t5_prompt_embeds, llama3_prompt_embeds)``. ``llama3_prompt_embeds``
            is indexed by every entry of ``self.llama_layers`` (presumably a
            per-layer stack of Llama-3 hidden states -- TODO confirm against
            the caller).

    Returns:
        Whatever ``make_contiguous`` returns for the five values
        ``(hidden_states, initial_encoder_hidden_states,
        llama_encoder_hidden_states, adaln_input, rope)``, where:

        * ``hidden_states`` is the ``x_embedder`` output.
        * ``initial_encoder_hidden_states`` is the projected T5 sequence
          concatenated along dim 1 with the projected last selected Llama layer.
        * ``llama_encoder_hidden_states`` stacks every projected Llama layer
          along a new dim 0.
        * ``adaln_input`` is the timestep embedding plus the pooled-text
          embedding.
        * ``rope`` is ``self.pe_embedder`` applied to the image ids followed by
          all-zero text ids.

        All of these except ``rope`` are cast to ``AUTOCAST_DTYPE``.

    Raises:
        RuntimeError: If ``self.caption_projection`` is ``None``.
    """
    hidden_states, img_ids, timesteps, pooled_embeds, t5_prompt_embeds, llama3_prompt_embeds = inputs

    # Explicit raise rather than ``assert``, so the check survives ``python -O``.
    # Without the projections the T5 sequence would never be added, and the text
    # token count and concatenations below would silently use the wrong tensors.
    if self.caption_projection is None:
        raise RuntimeError('caption_projection is required but is None')

    batch_size = hidden_states.shape[0]

    t_emb = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
    t_emb = self.t_embedder(t_emb, hidden_states.dtype)
    pooled_emb = self.p_embedder(pooled_embeds)
    adaln_input = t_emb + pooled_emb

    hidden_states = self.x_embedder(hidden_states)
    inner_dim = hidden_states.shape[-1]

    # Projection i handles the i-th selected Llama layer; the final projection handles T5.
    llama_states = [
        self.caption_projection[i](llama3_prompt_embeds[layer]).view(batch_size, -1, inner_dim)
        for i, layer in enumerate(self.llama_layers)
    ]
    t5_states = self.caption_projection[-1](t5_prompt_embeds).view(batch_size, -1, inner_dim)

    # One all-zero position id per text token.
    # Token count = T5 tokens + last selected Llama layer + first selected Llama layer.
    # The first-layer term presumably accounts for the per-block Llama state that is
    # concatenated later, with all projected layers sharing one length -- TODO confirm.
    txt_ids = torch.zeros(
        batch_size,
        t5_states.shape[1] + llama_states[-1].shape[1] + llama_states[0].shape[1],
        3,
        device=img_ids.device,
        dtype=img_ids.dtype,
    )
    rope = self.pe_embedder(torch.cat((img_ids, txt_ids), dim=1))

    initial_encoder_hidden_states = torch.cat([t5_states, llama_states[-1]], dim=1)
    llama_encoder_hidden_states = torch.stack(llama_states, dim=0)

    # With nf4 quantization, tensors can end up float32, which breaks flash attention later, so we cast it here.
    hidden_states = hidden_states.to(AUTOCAST_DTYPE)
    initial_encoder_hidden_states = initial_encoder_hidden_states.to(AUTOCAST_DTYPE)
    llama_encoder_hidden_states = llama_encoder_hidden_states.to(AUTOCAST_DTYPE)
    adaln_input = adaln_input.to(AUTOCAST_DTYPE)

    return make_contiguous(hidden_states, initial_encoder_hidden_states, llama_encoder_hidden_states, adaln_input, rope)
412
413
414class TransformerWrapper(nn.Module):

Callers

nothing calls this directly

Calls 2

make_contiguous · Function · 0.90
to · Method · 0.45

Tested by

no test coverage detected