(self, shape, device=None, dtype=None, layout=None, generator=None)
| 2501 | self.sampler_noises = noises |
| 2502 | |
def randn(self, shape, device=None, dtype=None, layout=None, generator=None):
    """Return the next pre-generated sampler noise, or fresh noise as a fallback.

    Serves ``self.sampler_noises[self.sampler_noise_index]`` when that entry
    exists and its shape equals ``shape``; otherwise logs a warning and draws
    new noise with ``torch.randn``. The index advances on every call, including
    fallbacks (presumably so later requests stay aligned with their
    pre-generated slots -- confirm against the caller).

    Args:
        shape: Requested shape. List and tuple shapes are compared
            element-wise against the stored tensor's shape, so ``[2, 3]``
            matches ``torch.Size([2, 3])``.
        device, dtype, generator: Forwarded to ``torch.randn`` on the fallback
            path only. A stored tensor is returned unchanged (no ``.to()``).
        layout: Accepted for signature compatibility with ``torch.randn`` but
            ignored.

    Returns:
        The stored tensor itself (not a copy) or a newly sampled tensor.
    """
    # logger.info(f"replacing {shape} {len(self.sampler_noises)} {self.sampler_noise_index}")
    noise = None
    noises = self.sampler_noises
    if noises is not None and self.sampler_noise_index < len(noises):
        candidate = noises[self.sampler_noise_index]
        # A list never compares equal to a tuple/torch.Size, so normalise
        # list/tuple shapes before comparing; other shape forms are compared
        # exactly as given.
        requested = tuple(shape) if isinstance(shape, (list, tuple)) else shape
        if requested == candidate.shape:
            noise = candidate

    # Use `is None`: `==` on a tensor dispatches to element-wise torch.eq and
    # only works here by accident of Python's comparison fallback.
    if noise is None:
        logger.warning(f"unexpected noise request: {self.sampler_noise_index}, {shape}")
        noise = torch.randn(shape, dtype=dtype, device=device, generator=generator)

    self.sampler_noise_index += 1
    return noise
| 2518 | |
| 2519 | class TorchRandReplacer: |
| 2520 | def __init__(self, noise_manager): |
no outgoing calls
no test coverage detected