(self, shape, device=None, dtype=None, layout=None, generator=None)
| 1777 | self.sampler_noises = noises |
| 1778 | |
def randn(self, shape, device=None, dtype=None, layout=None, generator=None):
    """Return the next pre-generated noise tensor, or fresh noise as a fallback.

    The tensor at ``self.sampler_noises[self.sampler_noise_index]`` is
    returned when it exists and its shape matches ``shape``.  Otherwise a
    warning is logged and a new tensor is drawn with ``torch.randn``.
    ``self.sampler_noise_index`` is advanced by one on every call, whichever
    path is taken.

    Args:
        shape: Requested noise shape.  Lists and tuples (including
            ``torch.Size``) are compared by value against the stored
            tensor's shape.
        device: Forwarded to ``torch.randn`` on the fallback path only.
        dtype: Forwarded to ``torch.randn`` on the fallback path only.
        layout: Accepted for signature compatibility; not forwarded.
        generator: Forwarded to ``torch.randn`` on the fallback path only.

    Returns:
        The stored noise tensor when available and shape-compatible,
        otherwise a freshly sampled tensor of the requested shape.
    """
    noise = None
    noises = self.sampler_noises
    if noises is not None and self.sampler_noise_index < len(noises):
        candidate = noises[self.sampler_noise_index]
        # torch.Size is a tuple subclass, so a list-typed ``shape`` would
        # never compare equal to it; normalise sequences before comparing.
        requested = tuple(shape) if isinstance(shape, (list, tuple)) else shape
        if requested == tuple(candidate.shape):
            noise = candidate

    # Identity check: ``== None`` would dispatch to the tensor's ``__eq__``.
    if noise is None:
        logger.warning(f"unexpected noise request: {self.sampler_noise_index}, {shape}")
        noise = torch.randn(shape, dtype=dtype, device=device, generator=generator)

    self.sampler_noise_index += 1
    return noise
| 1794 | |
| 1795 | class TorchRandReplacer: |
| 1796 | def __init__(self, noise_manager): |
no outgoing calls