(
self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor
)
| 247 | return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] |
| 248 | |
def concat_encodings(
    self,
    lg_out: torch.Tensor,
    t5_out: Optional[torch.Tensor],
    lg_pooled: torch.Tensor,
    *,
    target_dim: int = 4096,
    t5_placeholder_len: int = 77,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Join the CLIP and T5 token encodings into a single context sequence.

    ``lg_out`` is zero-padded on its last axis up to ``target_dim`` so it can be
    concatenated with ``t5_out`` along ``dim=-2``. When ``t5_out`` is ``None``, an
    all-zero placeholder of ``t5_placeholder_len`` tokens (same device and dtype as
    ``lg_out``) is used in its place.

    Args:
        lg_out: Hidden states, presumably the combined CLIP-L/G output; assumed to be
            (batch, seq, dim) with ``dim <= target_dim`` -- TODO confirm against caller.
        t5_out: Hidden states whose last axis is ``target_dim``, or ``None``.
        lg_pooled: Pooled output; returned unchanged.
        target_dim: Shared last-axis width after padding (default 4096, as before).
        t5_placeholder_len: Token count of the zero placeholder used when ``t5_out``
            is ``None`` (default 77, as before).

    Returns:
        ``(context, lg_pooled)``, where ``context`` is the padded ``lg_out`` followed
        by ``t5_out`` (or the placeholder) along ``dim=-2``.

    Raises:
        ValueError: If ``lg_out``'s last axis is wider than ``target_dim``.
    """
    width = lg_out.shape[-1]
    # A negative pad amount makes F.pad crop instead of pad, which would silently
    # drop features; fail loudly instead.
    if width > target_dim:
        raise ValueError(f"lg_out last dim {width} exceeds target_dim {target_dim}")
    lg_out = torch.nn.functional.pad(lg_out, (0, target_dim - width))
    if t5_out is None:
        t5_out = torch.zeros(
            (lg_out.shape[0], t5_placeholder_len, target_dim),
            device=lg_out.device,
            dtype=lg_out.dtype,
        )
    return torch.cat([lg_out, t5_out], dim=-2), lg_pooled
| 256 | |
| 257 | |
| 258 | class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): |
no outgoing calls
no test coverage detected