| 381 | |
| 382 | |
class GradualLatent:
    """Settings and resize helpers for "gradual latent" upscaling.

    The constructor only stores its arguments. The ratio/timestep/noise
    settings are not read by this class itself; they are presumably consumed
    by the ``*SchedulerGL`` scheduler(s) that use this object (confirm there).
    ``interpolate`` resizes a latent tensor and can optionally sharpen it with
    an unsharp mask.
    """

    def __init__(
        self,
        ratio,
        start_timesteps,
        every_n_steps,
        ratio_step,
        s_noise=1.0,
        gaussian_blur_ksize=None,
        gaussian_blur_sigma=0.5,
        gaussian_blur_strength=0.5,
        unsharp_target_x=True,
    ):
        """Store the gradual-latent settings.

        Args:
            ratio, start_timesteps, every_n_steps, ratio_step, s_noise:
                Stored as-is; not used by this class's own methods.
            gaussian_blur_ksize: Kernel size for the Gaussian blur of the
                unsharp mask. ``None`` disables sharpening.
            gaussian_blur_sigma: Sigma for the Gaussian blur.
            gaussian_blur_strength: Multiplier applied to ``x - blurred``
                when sharpening.
            unsharp_target_x: Stored as-is; not used by this class's own
                methods.
        """
        self.ratio = ratio
        self.start_timesteps = start_timesteps
        self.every_n_steps = every_n_steps
        self.ratio_step = ratio_step
        self.s_noise = s_noise
        self.gaussian_blur_ksize = gaussian_blur_ksize
        self.gaussian_blur_sigma = gaussian_blur_sigma
        self.gaussian_blur_strength = gaussian_blur_strength
        self.unsharp_target_x = unsharp_target_x

    def __str__(self) -> str:
        return (
            f"GradualLatent(ratio={self.ratio}, start_timesteps={self.start_timesteps}, "
            f"every_n_steps={self.every_n_steps}, ratio_step={self.ratio_step}, s_noise={self.s_noise}, "
            f"gaussian_blur_ksize={self.gaussian_blur_ksize}, gaussian_blur_sigma={self.gaussian_blur_sigma}, gaussian_blur_strength={self.gaussian_blur_strength}, "
            f"unsharp_target_x={self.unsharp_target_x})"
        )

    def apply_unsharp_mask(self, x: torch.Tensor) -> torch.Tensor:
        """Sharpen ``x`` with an unsharp mask.

        Computes ``x + gaussian_blur_strength * (x - blur(x))``, where ``blur``
        is ``transforms.functional.gaussian_blur`` (presumably torchvision)
        called with ``gaussian_blur_ksize`` and ``gaussian_blur_sigma``.

        Returns:
            The sharpened tensor, or ``x`` itself (unchanged) when
            ``gaussian_blur_ksize`` is ``None``.
        """
        if self.gaussian_blur_ksize is None:
            return x
        blurred = transforms.functional.gaussian_blur(x, self.gaussian_blur_ksize, self.gaussian_blur_sigma)
        # Scale the high-frequency detail (x - blurred) and add it back to x.
        return x + (x - blurred) * self.gaussian_blur_strength

    def apply_unshark_mask(self, x: torch.Tensor) -> torch.Tensor:
        """Backward-compatible alias for the misspelled name; see ``apply_unsharp_mask``."""
        return self.apply_unsharp_mask(x)

    def interpolate(self, x: torch.Tensor, resized_size, unsharp=True) -> torch.Tensor:
        """Resize ``x`` with bicubic interpolation and optionally sharpen it.

        Args:
            x: Tensor to resize. ``torch.nn.functional.interpolate`` with
                ``mode="bicubic"`` requires a 4-D (N, C, H, W) input.
            resized_size: Target spatial size, forwarded as ``size=``.
            unsharp: When true and ``gaussian_blur_ksize`` is truthy, apply
                the unsharp mask after resizing.

        Returns:
            The resized (and possibly sharpened) tensor, in ``x``'s original
            dtype.
        """
        org_dtype = x.dtype
        # bfloat16 is upcast to float32 for the whole resize + sharpen pipeline.
        # NOTE(review): presumably because bicubic interpolation is not
        # implemented for bfloat16 on some backends -- confirm.
        if org_dtype == torch.bfloat16:
            x = x.float()

        x = torch.nn.functional.interpolate(x, size=resized_size, mode="bicubic", align_corners=False)

        # Sharpen before casting back: (x - blurred) is a small difference of
        # nearly equal values and loses most of its precision in bfloat16.
        if unsharp and self.gaussian_blur_ksize:
            x = self.apply_unsharp_mask(x)

        return x.to(dtype=org_dtype)
| 435 | |
| 436 | |
| 437 | class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler): |