sub_batch_size: How many images to decode in a single pass. See https://github.com/huchenlei/ComfyUI-layerdiffuse/pull/4 for more context.
(self, samples, images, sd_version: str, sub_batch_size: int)
| 135 | self.vae_transparent_decoder = {} |
| 136 | |
def decode(self, samples, images, sd_version: str, sub_batch_size: int):
    """
    Split decoder output into an (image, alpha) pair using the LayerDiffusion
    transparent VAE decoder.

    sub_batch_size: How many images to decode in a single pass.
    See https://github.com/huchenlei/ComfyUI-layerdiffuse/pull/4 for more
    context.

    Args:
        samples: Latent dict; its "samples" tensor is sliced along dim 0 in
            chunks of ``sub_batch_size``.
        images: Image tensor in [B, H, W, C] layout; H and W must be
            multiples of 64.
        sd_version: Value accepted by ``StableDiffusionVersion``. Only SD1x
            and SDXL have a transparent decoder here.
        sub_batch_size: Number of images passed to the decoder per call;
            must be >= 1.

    Returns:
        (image, alpha): ``image`` is channels 1: and ``alpha`` is channel 0
        of the decoder output, both moved back to channels-last layout.

    Raises:
        ValueError: if ``sd_version`` has no transparent decoder or
            ``sub_batch_size`` is less than 1.
        AssertionError: if H or W is not a multiple of 64.
    """
    sd_version = StableDiffusionVersion(sd_version)
    if sd_version == StableDiffusionVersion.SD1x:
        url = "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_vae_transparent_decoder.safetensors"
        file_name = "layer_sd15_vae_transparent_decoder.safetensors"
    elif sd_version == StableDiffusionVersion.SDXL:
        url = "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/vae_transparent_decoder.safetensors"
        file_name = "vae_transparent_decoder.safetensors"
    else:
        # Without this branch `url`/`file_name` would be unbound below and
        # the caller would see an opaque UnboundLocalError.
        raise ValueError(
            f"No transparent VAE decoder available for SD version {sd_version}."
        )

    # Validate before any (potentially slow) model download.
    if sub_batch_size < 1:
        raise ValueError(f"sub_batch_size must be >= 1, got {sub_batch_size}.")

    # Lazily load and cache one decoder per SD version.
    decoder = self.vae_transparent_decoder.get(sd_version)
    if decoder is None:
        model_path = load_file_from_url(
            url=url, model_dir=layer_model_root, file_name=file_name
        )
        decoder = TransparentVAEDecoder(
            load_torch_file(model_path),
            device=comfy.model_management.get_torch_device(),
            dtype=(
                torch.float16
                if comfy.model_management.should_use_fp16()
                else torch.float32
            ),
        )
        self.vae_transparent_decoder[sd_version] = decoder

    pixel = images.movedim(-1, 1)  # [B, H, W, C] => [B, C, H, W]

    # Decoder requires dimension to be 64-aligned.
    # Unpacking four values also enforces a rank-4 tensor.
    _, _, height, width = pixel.shape
    assert height % 64 == 0, f"Height({height}) is not multiple of 64."
    assert width % 64 == 0, f"Width({width}) is not multiple of 64."

    latent = samples["samples"]
    # Decode in chunks of `sub_batch_size`; the batch count is driven by
    # the latent batch, and the pixel batch is sliced in lockstep.
    decoded = [
        decoder.decode_pixel(
            pixel[start : start + sub_batch_size],
            latent[start : start + sub_batch_size],
        )
        for start in range(0, latent.shape[0], sub_batch_size)
    ]
    pixel_with_alpha = torch.cat(decoded, dim=0)

    # [B, C, H, W] => [B, H, W, C]
    pixel_with_alpha = pixel_with_alpha.movedim(1, -1)
    # Channel 0 of the decoder output is alpha; the remaining channels are
    # the image.
    image = pixel_with_alpha[..., 1:]
    alpha = pixel_with_alpha[..., 0]
    return (image, alpha)
| 186 | |
| 187 | |
| 188 | class LayeredDiffusionDecodeRGBA(LayeredDiffusionDecode): |
no test coverage detected