(
self,
sigma_sampler_config,
type="l2",
offset_noise_level=0.0,
batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
)
| 10 | |
| 11 | class StandardDiffusionLoss(nn.Module): |
def __init__(
    self,
    sigma_sampler_config,
    type="l2",
    offset_noise_level=0.0,
    batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
):
    """Configure the loss: sigma sampler, loss type, and extra batch keys.

    Args:
        sigma_sampler_config: Config handed to ``instantiate_from_config``
            to build ``self.sigma_sampler``.
        type: Loss variant; must be one of ``"l2"``, ``"l1"`` or
            ``"lpips"``. The name shadows the builtin but is kept because
            it is part of the public keyword interface.
        offset_noise_level: Stored unchanged on the instance as
            ``self.offset_noise_level``.
        batch2model_keys: A single key, a collection of keys, or ``None``.
            Normalized to a ``set`` stored as ``self.batch2model_keys``;
            any falsy value (``None``, ``""``, empty list) gives an empty
            set. Presumably these name batch entries forwarded to the
            model -- confirm against ``__call__``.

    Raises:
        ValueError: If ``type`` is not one of the supported loss types.
    """
    super().__init__()

    # Validate with an explicit raise rather than ``assert``: asserts are
    # stripped under ``python -O``, which would let an unsupported loss
    # type through unnoticed.
    if type not in ("l2", "l1", "lpips"):
        raise ValueError(
            f"Unsupported loss type {type!r}; expected 'l2', 'l1' or 'lpips'"
        )

    self.sigma_sampler = instantiate_from_config(sigma_sampler_config)

    self.type = type
    self.offset_noise_level = offset_noise_level

    if type == "lpips":
        # Only build the LPIPS module when that loss is selected; it is
        # switched to eval mode immediately.
        self.lpips = LPIPS().eval()

    if not batch2model_keys:
        batch2model_keys = []

    # A bare string is a single key, not an iterable of characters.
    if isinstance(batch2model_keys, str):
        batch2model_keys = [batch2model_keys]

    self.batch2model_keys = set(batch2model_keys)
| 38 | |
| 39 | def __call__(self, network, denoiser, conditioner, input, batch): |
| 40 | cond = conditioner(batch) |
nothing calls this directly
no test coverage detected