(
noise_pred_pos,
noise_pred_neg,
noise_pred_pertubed,
cfg_scale,
stg_scale,
rescale_scale,
)
| 14 | |
| 15 | |
def stg(
    noise_pred_pos,
    noise_pred_neg,
    noise_pred_pertubed,
    cfg_scale,
    stg_scale,
    rescale_scale,
):
    """Blend classifier-free guidance and skip guidance, with optional rescaling.

    The guided prediction is the positive prediction plus two correction terms:

    * ``(cfg_scale - 1) * (pos - neg)``: moves away from the negative
      prediction (a ``cfg_scale`` of ``1`` makes this term zero);
    * ``stg_scale * (pos - perturbed)``: moves away from the perturbed
      prediction (a ``stg_scale`` of ``0`` makes this term zero).

    If ``rescale_scale`` is non-zero, the result is multiplied by
    ``rescale_scale * (pos.std() / guided.std()) + (1 - rescale_scale)``.
    This pulls the guided prediction's standard deviation toward that of the
    positive prediction. ``.std()`` is called with no arguments, so it is the
    inputs' own all-element standard deviation (not a per-sample one).

    Args:
        noise_pred_pos: Positive-branch prediction; the base of the result.
        noise_pred_neg: Negative-branch prediction used in the CFG term.
        noise_pred_pertubed: Perturbed prediction used in the skip-guidance
            term (spelling of the name kept for caller compatibility).
        cfg_scale: Classifier-free guidance scale.
        stg_scale: Skip-guidance scale.
        rescale_scale: Weight of the std-matching factor; ``0`` skips the
            rescaling step entirely (and never calls ``.std()``).

    Returns:
        The guided (and possibly rescaled) prediction.
    """
    cfg_term = (cfg_scale - 1) * (noise_pred_pos - noise_pred_neg)
    stg_term = stg_scale * (noise_pred_pos - noise_pred_pertubed)
    # Keep the original left-to-right summation order: (pos + cfg) + stg.
    guided = noise_pred_pos + cfg_term + stg_term

    if rescale_scale == 0:
        return guided

    # Interpolate between "no rescale" (1) and "match pos std" (ratio).
    std_ratio = noise_pred_pos.std() / guided.std()
    blend = rescale_scale * std_ratio + (1 - rescale_scale)
    return guided * blend
| 34 | |
| 35 | |
| 36 | class MomentumBuffer: |
no outgoing calls
no test coverage detected