github.com/huggingface/diffusers

Function extract_lora(diff, rank)

scripts/extract_lora_from_model.py, lines 40–72
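In matrix notation (the symbols are ours, not the script's), let ΔW be `diff` viewed as an m×n matrix, where m = out_dim and n is in_dim for linear and 1×1 conv weights or in_dim·k_h·k_w for larger kernels, and let n_0 = in_dim. The function takes a rank-r truncated SVD, folds the singular values into the left factor B (returned as `U`), returns A as `Vh`, and clamps both to [-c, c], where Q_p is the p = CLAMP_QUANTILE quantile of their pooled entries. Without the clamp, BA is the best rank-r approximation of ΔW in the Frobenius norm (Eckart–Young). For conv weights, B and A are then reshaped to (out_dim, r, 1, 1) and (r, in_dim, k_h, k_w).

```latex
\begin{aligned}
\Delta W &= U \Sigma V^{\top}, \qquad r = \min(\mathrm{rank},\, m,\, n_0),\\
B &= U_{:,1:r}\,\operatorname{diag}(\sigma_1,\dots,\sigma_r) \in \mathbb{R}^{m\times r}, \qquad
A = \bigl(V_{:,1:r}\bigr)^{\top} \in \mathbb{R}^{r\times n}, \qquad \Delta W \approx BA,\\
c &= Q_p\bigl(\text{entries of } B \text{ and } A\bigr), \qquad
B \leftarrow \operatorname{clip}(B,-c,c), \quad A \leftarrow \operatorname{clip}(A,-c,c).
\end{aligned}
```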


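import torch

# Assumption: CLAMP_QUANTILE is a module-level constant defined elsewhere in the
# script (outside lines 40–72); the value below is assumed, not copied from it.
CLAMP_QUANTILE = 0.99


# Comes from
# https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/001154622564b17223ce0191803c5fff7b87146c/control_lora_create.py#L9
def extract_lora(diff, rank):
    # Use CUDA when available; otherwise the SVD is very slow.
    if torch.cuda.is_available():
        diff = diff.to("cuda")

    is_conv2d = len(diff.shape) == 4
    kernel_size = None if not is_conv2d else diff.size()[2:4]
    # True for any kernel other than 1x1, not only 3x3.
    is_conv2d_3x3 = is_conv2d and kernel_size != (1, 1)
    out_dim, in_dim = diff.size()[0:2]
    # The rank cannot exceed the output or input channel count.
    rank = min(rank, in_dim, out_dim)

    if is_conv2d:
        if is_conv2d_3x3:
            # (out, in, kh, kw) -> (out, in * kh * kw)
            diff = diff.flatten(start_dim=1)
        else:
            # (out, in, 1, 1) -> (out, in); reshape rather than squeeze() so a
            # size-1 out or in dimension is not dropped as well.
            diff = diff.reshape(out_dim, in_dim)

    # Truncated SVD: keep the top-`rank` singular triplets and fold the
    # singular values into U, so that diff ≈ U @ Vh.
    U, S, Vh = torch.linalg.svd(diff.float())
    U = U[:, :rank]
    S = S[:rank]
    U = U @ torch.diag(S)
    Vh = Vh[:rank, :]

    # Clamp outliers symmetrically at the CLAMP_QUANTILE quantile of the
    # pooled entries of both factors.
    dist = torch.cat([U.flatten(), Vh.flatten()])
    hi_val = torch.quantile(dist, CLAMP_QUANTILE)
    low_val = -hi_val

    U = U.clamp(low_val, hi_val)
    Vh = Vh.clamp(low_val, hi_val)
    if is_conv2d:
        # Restore conv shapes: U becomes a 1x1 "up" conv, Vh a kh x kw "down" conv.
        U = U.reshape(out_dim, rank, 1, 1)
        Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
    return (U.cpu(), Vh.cpu())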
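A minimal usage sketch for the 2-D (linear layer) path, assuming PyTorch is installed. Because the listing depends on a module-level constant, the sketch re-implements that path inline as a hypothetical helper `truncated_lora` (our name, not the script's) with clamping optional, builds an exactly rank-4 difference, and checks the factor shapes and reconstruction.

```python
import torch


def truncated_lora(diff, rank, clamp_quantile=None):
    # Hypothetical helper mirroring extract_lora's 2-D path (not the script's API).
    out_dim, in_dim = diff.shape
    rank = min(rank, in_dim, out_dim)  # cap the rank at the matrix dimensions
    # Reduced SVD: the truncated product up @ down equals the full-SVD result.
    U, S, Vh = torch.linalg.svd(diff.float(), full_matrices=False)
    up = U[:, :rank] @ torch.diag(S[:rank])  # fold singular values into the left factor
    down = Vh[:rank, :]
    if clamp_quantile is not None:
        # Symmetric clamp at a quantile of the pooled factor entries, as in the listing.
        hi = torch.quantile(torch.cat([up.flatten(), down.flatten()]), clamp_quantile)
        up = up.clamp(-hi, hi)
        down = down.clamp(-hi, hi)
    return up, down


torch.manual_seed(0)
# An exactly rank-4 weight difference for a 64 -> 32 linear layer (rows = out features).
diff = torch.randn(32, 4) @ torch.randn(4, 64)

up, down = truncated_lora(diff, rank=8)
rel_err = (torch.linalg.norm(diff - up @ down) / torch.linalg.norm(diff)).item()
print(tuple(up.shape), tuple(down.shape), rel_err)

# Asking for more rank than the matrix supports is capped: min(100, 64, 32) == 32.
up_full, down_full = truncated_lora(diff, rank=100)

# With clamping, factor magnitudes can only shrink relative to the unclamped factors.
up4, down4 = truncated_lora(diff, rank=4)
up_c, down_c = truncated_lora(diff, rank=4, clamp_quantile=0.99)
```

Passing `full_matrices=False` avoids materializing the larger square factors that the listing's default full SVD computes; only the top `rank` triplets are used either way.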

Callers (1): main (function) · 0.85

Calls (2): float (method) · 0.80, to (method) · 0.45

Tested by: no test coverage detected
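One property a test for the conv path could check, sketched here assuming PyTorch is available: the returned factors compose as two convolutions, a k×k conv with the `Vh` factor (producing `rank` channels) followed by a 1×1 conv with the `U` factor, which reproduces a conv with the original weight when no rank is discarded. The sketch repeats the conv reshaping inline as a hypothetical helper `conv_lora_factors` (our name) without clamping, rather than importing the script.

```python
import torch
import torch.nn.functional as F


def conv_lora_factors(diff, rank):
    # Hypothetical helper mirroring extract_lora's k x k conv path, without clamping.
    out_dim, in_dim, kh, kw = diff.shape
    rank = min(rank, in_dim, out_dim)
    # (out, in, kh, kw) -> (out, in * kh * kw), then truncated SVD.
    U, S, Vh = torch.linalg.svd(diff.flatten(start_dim=1), full_matrices=False)
    up = (U[:, :rank] @ torch.diag(S[:rank])).reshape(out_dim, rank, 1, 1)
    down = Vh[:rank, :].reshape(rank, in_dim, kh, kw)
    return up, down


torch.manual_seed(0)
in_dim, out_dim = 8, 6
diff = torch.randn(out_dim, in_dim, 3, 3)  # flattened to a 6 x 72 matrix of rank 6
x = torch.randn(1, in_dim, 16, 16)
full = F.conv2d(x, diff, padding=1)

# rank 6 keeps every singular value, so the composition is exact up to float error.
up, down = conv_lora_factors(diff, rank=6)
lora = F.conv2d(F.conv2d(x, down, padding=1), up)  # down (3x3) then up (1x1)
max_err = (full - lora).abs().max().item()

# rank 2 discards singular values, so the composed output deviates noticeably.
up2, down2 = conv_lora_factors(diff, rank=2)
trunc_err = (full - F.conv2d(F.conv2d(x, down2, padding=1), up2)).abs().max().item()
```

The composition works because convolution is linear in its weight: summing the 1×1 up-projection over the `rank` intermediate channels yields a single kernel whose flattened form is U @ Vh. That is why the spatial kernel lives entirely in the down factor.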
