# Comes from
# https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/001154622564b17223ce0191803c5fff7b87146c/control_lora_create.py#L9
def extract_lora(diff, rank, clamp_quantile=None):
    """Approximate a weight delta with a rank-``rank`` LoRA factor pair.

    The delta is factorised with a truncated SVD so that ``up @ down`` (after
    flattening any conv dimensions) approximates ``diff``. Both factors are
    then clipped to a symmetric bound to suppress outliers.

    Args:
        diff: weight difference. Either 2-D ``(out_dim, in_dim)`` or 4-D
            ``(out_dim, in_dim, kh, kw)``; the 4-D case is treated as a conv.
        rank: requested LoRA rank; capped at ``min(in_dim, out_dim)``.
        clamp_quantile: quantile of the absolute factor values used as the
            clipping bound. Defaults to the module-level ``CLAMP_QUANTILE``.

    Returns:
        ``(up, down)`` as float32 CPU tensors. For 2-D input these have shapes
        ``(out_dim, rank)`` and ``(rank, in_dim)``. For 4-D input they have
        shapes ``(out_dim, rank, 1, 1)`` and ``(rank, in_dim, kh, kw)``.
    """
    if clamp_quantile is None:
        clamp_quantile = CLAMP_QUANTILE

    # Important to use CUDA otherwise, very slow!
    if torch.cuda.is_available():
        diff = diff.to("cuda")

    is_conv2d = diff.dim() == 4
    kernel_size = tuple(diff.shape[2:4]) if is_conv2d else None
    out_dim, in_dim = diff.shape[0:2]
    rank = min(rank, in_dim, out_dim)

    if is_conv2d:
        # Collapse (in_dim, kh, kw) into a single axis. Unlike squeeze(), this
        # keeps the matrix 2-D for 1x1 kernels even when in_dim or out_dim is 1.
        diff = diff.flatten(start_dim=1)

    # Reduced SVD: only the leading `rank` components are kept, so there is no
    # need to materialise the full square U / Vh (very large for flattened convs).
    U, S, Vh = torch.linalg.svd(diff.float(), full_matrices=False)
    U = U[:, :rank] * S[:rank]  # same as U @ diag(S): fold singular values into "up"
    Vh = Vh[:rank, :]

    # Symmetric outlier clipping. Taking the quantile of |x| guarantees a
    # non-negative bound; a signed quantile can be <= 0, and clamp() with
    # min > max would then set every entry to the same value.
    dist = torch.cat([U.flatten(), Vh.flatten()])
    hi_val = torch.quantile(dist.abs(), clamp_quantile)
    U = U.clamp(-hi_val, hi_val)
    Vh = Vh.clamp(-hi_val, hi_val)

    if is_conv2d:
        U = U.reshape(out_dim, rank, 1, 1)
        Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
    return (U.cpu(), Vh.cpu())
| 73 | |
| 74 | |
| 75 | def parse_args(): |