(
self,
batch_size: int,
latent_height: int,
latent_width: int,
text_len: int,
device: str = "cuda",
dtype: torch.dtype = torch.float32,
)
| 84 | return dyn_axes |
| 85 | |
def get_sample_input(
    self,
    batch_size: int,
    latent_height: int,
    latent_width: int,
    text_len: int,
    device: str = "cuda",
    dtype: torch.dtype = torch.float32,
) -> "Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]":
    """Build randomly initialised sample inputs for this model.

    Every tensor is drawn with ``torch.randn`` on ``device`` with ``dtype``.
    NOTE(review): presumably used as dummy inputs for tracing/export — confirm
    against callers.

    Args:
        batch_size: Leading dimension of every returned tensor.
        latent_height: Height (third dimension) of the first tensor.
        latent_width: Width (fourth dimension) of the first tensor.
        text_len: Second dimension of the third tensor.
        device: Device the tensors are allocated on.
        dtype: Element dtype of the tensors.

    Returns:
        A 4-tuple, in this order:
          1. tensor of shape ``(batch_size, self.in_channels, latent_height,
             latent_width)``;
          2. tensor of shape ``(batch_size,)`` (presumably timesteps — confirm);
          3. tensor of shape ``(batch_size, text_len, self.embedding_dim)``;
          4. tensor of shape ``(batch_size, self.num_xl_classes)`` when
             ``self.is_xl`` is truthy, otherwise ``None``.
    """
    # The return annotation is quoted so that ``Optional`` is only needed by
    # static type checkers, not at runtime. The previous ``Tuple[torch.Tensor]``
    # described a 1-tuple, which did not match the 4-tuple returned here.

    # Tensors are created in the same order as before, so a seeded RNG
    # produces identical values.
    latent = torch.randn(
        batch_size,
        self.in_channels,
        latent_height,
        latent_width,
        dtype=dtype,
        device=device,
    )
    per_sample = torch.randn(batch_size, dtype=dtype, device=device)
    text_context = torch.randn(
        batch_size,
        text_len,
        self.embedding_dim,
        dtype=dtype,
        device=device,
    )
    # Only XL-style models take the extra class-conditioning input.
    class_input = (
        torch.randn(batch_size, self.num_xl_classes, dtype=dtype, device=device)
        if self.is_xl
        else None
    )
    return latent, per_sample, text_context, class_input
| 116 | |
| 117 | def get_input_profile(self, profile: ProfileSettings) -> dict: |
| 118 | min_batch, opt_batch, max_batch = profile.get_a1111_batch_dim() |
no outgoing calls
no test coverage detected