A helper function to create random tensors on the desired `device` with the desired `dtype`. When passing a list of generators, you can seed each batch element individually. If CPU generators are passed, the tensor is always created on the CPU and then moved to the desired `device`.
(
shape: tuple | list,
generator: list["torch.Generator"] | "torch.Generator" | None = None,
device: str | "torch.device" | None = None,
dtype: "torch.dtype" | None = None,
layout: "torch.layout" | None = None,
)
| 150 | |
| 151 | |
def randn_tensor(
    shape: tuple | list,
    generator: list["torch.Generator"] | "torch.Generator" | None = None,
    device: str | "torch.device" | None = None,
    dtype: "torch.dtype" | None = None,
    layout: "torch.layout" | None = None,
):
    """Create a standard-normal random tensor on `device` with the given `dtype`.

    When a list of generators is passed, each batch element is sampled with its own
    generator, so every element can be seeded individually. If the generator(s) live
    on the CPU while a different `device` is requested, the tensor is sampled on the
    CPU and then moved to `device`.

    Args:
        shape: Shape of the tensor to create; `shape[0]` is used as the batch size.
        generator: A single generator, or a list of generators (one per batch
            element). A list of length 1 is treated like a single generator.
        device: Target device of the returned tensor. Strings are converted to
            `torch.device`. When `None`, the result is placed on the CPU.
        dtype: Forwarded to `torch.randn`.
        layout: Forwarded to `torch.randn`; defaults to `torch.strided`.

    Returns:
        The sampled tensor, moved to `device`.

    Raises:
        ValueError: If a CUDA generator is passed but `device` is not a CUDA device.
    """
    # device on which tensor is created defaults to device
    if isinstance(device, str):
        device = torch.device(device)
    # Captured before the CPU fallback below, so `rand_device` may stay `None`.
    rand_device = device
    batch_size = shape[0]

    layout = layout or torch.strided
    device = device or torch.device("cpu")

    if generator is not None:
        gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
        if gen_device_type != device.type and gen_device_type == "cpu":
            rand_device = "cpu"
            # `device` is a `torch.device` here; comparing it to the string "mps" is
            # never equal, so the device *type* has to be checked instead.
            if device.type != "mps":
                logger.info(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slightly speed up this function by passing a generator that was created on the {device} device."
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")

    # make sure generator list of length 1 is treated like a non-list
    if isinstance(generator, list) and len(generator) == 1:
        generator = generator[0]

    if isinstance(generator, list):
        # `shape` may be a list; convert the tail so the tuple concatenation works.
        shape = (1,) + tuple(shape[1:])
        # One sample per batch element, each drawn from its own generator.
        latents = [
            torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
            for i in range(batch_size)
        ]
        latents = torch.cat(latents, dim=0).to(device)
    else:
        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)

    return latents
| 200 | |
| 201 | |
| 202 | def is_compiled_module(module) -> bool: |
searching dependent graphs…