(c, p, target_model=None, target_clip=None)
@torch.no_grad()
@torch.inference_mode()
def clip_separate_inner(c, p, target_model=None, target_clip=None):
    """Adapt a text-conditioning tensor ``c`` (and pooled output ``p``) to ``target_model``.

    Branches:
      * ``target_model`` is None or an ``SDXLRefiner``: keep the last 1280
        features of ``c``; ``p`` is returned unchanged.
      * ``target_model`` is an ``SDXL``: keep ``c`` whole (cloned); ``p`` unchanged.
      * anything else: drop ``p`` (return None), keep the first 768 features
        of ``c`` and re-apply the final layer norm of
        ``target_clip.cond_stage_model.clip_l``'s text model to them.

    Args:
        c: Conditioning tensor. NOTE(review): assumed (batch, tokens, features),
            since dim 1 is split into 77-token chunks and the last dim is sliced;
            confirm against callers.
        p: Pooled conditioning, passed through or replaced by None.
        target_model: Model the conditioning is being prepared for.
        target_clip: CLIP wrapper providing the final layer norm; only used in
            the "anything else" branch.

    Returns:
        Tuple ``(c, p)``. In the layer-norm branch ``c`` is returned on its
        original device and dtype.
    """
    if target_model is None or isinstance(target_model, SDXLRefiner):
        return c[..., -1280:].clone(), p
    if isinstance(target_model, SDXL):
        return c.clone(), p

    p = None
    c = c[..., :768].clone()

    final_layer_norm = target_clip.cond_stage_model.clip_l.transformer.text_model.final_layer_norm

    final_layer_norm_origin_device = final_layer_norm.weight.device
    final_layer_norm_origin_dtype = final_layer_norm.weight.dtype

    c_origin_device = c.device
    c_origin_dtype = c.dtype

    # The norm is evaluated on CPU in float32. The module is shared, so it must
    # be moved back to its original device/dtype even if anything below raises.
    final_layer_norm.to(device='cpu', dtype=torch.float32)
    try:
        c = c.to(device='cpu', dtype=torch.float32)

        # Apply the norm per 77-token chunk along dim 1. Clamp to >= 1 chunk:
        # torch.chunk rejects chunks=0, which the bare `size // 77` produced for
        # sequences shorter than 77 tokens.
        # NOTE(review): if this is a per-token LayerNorm, chunking does not
        # affect the result; it is kept to preserve existing behavior.
        num_chunks = max(1, int(c.size(1)) // 77)
        c = torch.cat([final_layer_norm(ci) for ci in torch.chunk(c, num_chunks, 1)], dim=1)
    finally:
        final_layer_norm.to(device=final_layer_norm_origin_device, dtype=final_layer_norm_origin_dtype)

    c = c.to(device=c_origin_device, dtype=c_origin_dtype)
    return c, p
| 50 | |
| 51 | |
| 52 | @torch.no_grad() |
no test coverage detected