github.com/Yuanshi9815/OminiControl

Function lora_forward

omini/pipeline/flux_omini.py:144–184

Apply a single, explicitly-selected LoRA adapter to ``module`` at scale 1.0. This is a fast, allocation-free replacement for the previous ``specify_lora`` context manager, which mutated ``module.scaling`` on every call; its semantics are preserved exactly.
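In PEFT's LoRA layers, each active adapter $i$ adds its delta multiplied by ``scaling[i]`` (by default ``lora_alpha / r``). Setting the selected adapter $a$ to 1 and every other adapter to 0, as the docstring says ``specify_lora`` did, collapses the sum to a single unscaled term, which is what ``lora_forward`` computes directly. A sketch of that reduction (the notation is ours: $\mathcal{A}$ is the set of active adapters with weights on the layer, $\mathrm{drop}_i$ the adapter's dropout, $a \in \mathcal{A}$ the selected adapter):

```latex
y = \mathrm{base}(x) + \sum_{i \in \mathcal{A}} s_i \, B_i A_i \, \mathrm{drop}_i(x),
\qquad s_a = 1, \quad s_i = 0 \ (i \neq a)

\Longrightarrow \quad y = \mathrm{base}(x) + B_a A_a \, \mathrm{drop}_a(x)
```

One consequence: any ``lora_alpha / r`` factor configured through PEFT is bypassed, because the scale is fixed at 1.0 rather than read from ``module.scaling``.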

lora_forward(module, x: torch.Tensor, adapter) -> torch.Tensor


# Module-level imports this excerpt relies on.
import torch
from peft.tuners.tuners_utils import BaseTunerLayer


def lora_forward(module, x: torch.Tensor, adapter) -> torch.Tensor:
    """
    Apply a single, explicitly-selected LoRA adapter to ``module`` at scale 1.0.

    This is a fast, allocation-free replacement for the previous
    ``specify_lora`` context manager, which mutated ``module.scaling`` on every
    call. Semantics are preserved exactly:

    * If ``module`` is not a PEFT ``BaseTunerLayer`` it is called directly.
    * The base (non-LoRA) projection is always applied.
    * When ``adapter`` is not None and is an *active* adapter with weights on
      this module, its LoRA delta is added with scale hardcoded to ``1.0``
      (matching ``specify_lora`` setting ``scaling[adapter] = 1``). All other
      adapters contribute nothing (matching ``scaling[other] = 0``).

    LoRA dropout is applied via the adapter's own dropout module, which is a
    no-op at inference/eval (``nn.Identity`` when ``p == 0``, otherwise
    ``nn.Dropout`` in eval mode), so inference stays bit-identical to the
    ``specify_lora`` path while training matches PEFT's semantics.
    """
    if not isinstance(module, BaseTunerLayer):
        return module(x)

    result = module.base_layer(x)

    # No adapter requested, adapters disabled, or weights already merged into
    # the base layer -> nothing more to add (mirrors PEFT's forward).
    if adapter is None or module.disable_adapters or module.merged:
        return result

    # Only an *active* adapter with LoRA weights on this module contributes,
    # exactly as in the original specify_lora + PEFT forward path.
    if adapter not in module.active_adapters or adapter not in module.lora_A:
        return result

    torch_result_dtype = result.dtype
    lora_A = module.lora_A[adapter]
    lora_B = module.lora_B[adapter]
    dropout = module.lora_dropout[adapter]  # no-op in eval mode; nn.Identity when p == 0
    x = x.to(lora_A.weight.dtype)
    result = result + lora_B(lora_A(dropout(x)))  # scale hardcoded to 1.0
    return result.to(torch_result_dtype)
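The branch structure of ``lora_forward`` can be checked on a toy model without PyTorch or PEFT. This is a minimal sketch under stated assumptions: ``ToyLoraLayer``, ``matvec``, ``toy_lora_forward`` and ``scaled_forward`` are hypothetical names, vectors are plain Python lists, dropout and dtype casts are omitted, merged weights are not modelled, and ``scaled_forward`` is a simplified reading of a scaling-based forward (every active adapter adds ``scaling[name] * delta``), not PEFT's actual code. It exercises the documented rules and checks that scales of 1 for the selected adapter and 0 for the rest give the same output as the explicit path.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional

Vector = List[float]
Linear = Callable[[Vector], Vector]


def matvec(m: List[List[float]]) -> Linear:
    """Hypothetical helper: a bias-free linear map v -> m @ v on plain lists."""
    return lambda v: [sum(w * xi for w, xi in zip(row, v)) for row in m]


@dataclass
class ToyLoraLayer:
    """Hypothetical stand-in for a PEFT LoRA layer. Only the attributes that
    lora_forward reads are modelled; dtypes and dropout are left out."""
    base_layer: Linear
    lora_A: Dict[str, Linear]
    lora_B: Dict[str, Linear]
    active_adapters: List[str]
    disable_adapters: bool = False
    merged: bool = False


def toy_lora_forward(module, x: Vector, adapter: Optional[str]) -> Vector:
    """Mirrors the branch structure of lora_forward on the toy layer."""
    if not isinstance(module, ToyLoraLayer):
        return module(x)                      # plain layer: call it directly
    result = module.base_layer(x)             # base projection always applied
    if adapter is None or module.disable_adapters or module.merged:
        return result                         # nothing more to add
    if adapter not in module.active_adapters or adapter not in module.lora_A:
        return result                         # inactive, or no weights here
    delta = module.lora_B[adapter](module.lora_A[adapter](x))
    return [r + d for r, d in zip(result, delta)]  # scale fixed at 1.0


def scaled_forward(module: ToyLoraLayer, x: Vector, scaling: Dict[str, float]) -> Vector:
    """Simplified scaling-based path: each active adapter adds scaling[name] * delta."""
    result = module.base_layer(x)
    for name in module.active_adapters:
        if name in module.lora_A:
            delta = module.lora_B[name](module.lora_A[name](x))
            result = [r + scaling[name] * d for r, d in zip(result, delta)]
    return result


# Two adapters on an identity base layer.
layer = ToyLoraLayer(
    base_layer=matvec([[1, 0], [0, 1]]),
    lora_A={"a": matvec([[1, 1]]), "b": matvec([[1, -1]])},
    lora_B={"a": matvec([[1], [2]]), "b": matvec([[3], [0]])},
    active_adapters=["a", "b"],
)
x = [2.0, 3.0]

print(toy_lora_forward(layer, x, "a"))                  # [7.0, 13.0]
print(scaled_forward(layer, x, {"a": 1.0, "b": 0.0}))   # [7.0, 13.0]
print(toy_lora_forward(layer, x, None))                 # [2.0, 3.0]
```

The two paths agree here because adding an exact zero contribution leaves a finite float unchanged. This is a toy check of the documented equivalence, not a proof for real tensors: with a non-finite delta, multiplying by a scale of 0 yields NaN, whereas the explicit path never evaluates the unselected adapter.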

Callers 6

_adanorm_zero_forward (function)
_feedforward_forward (function)
attn_forward (function)
single_block_forward (function)
transformer_forward (function)

Calls

no outgoing calls

Tested by

no test coverage detected