Method forward_with_cond_scale

dalle_pytorch/dalle_pytorch.py:564–574  ·  github.com/lucidrains/DALLE-pytorch
(self, *args, cond_scale = 1, cache = None, **kwargs)
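The `cond_scale` argument sets the strength of what is commonly called classifier-free guidance, applied here to the logits. Write $\ell_c$ for the logits of the ordinary (text-conditioned) pass, $\ell_\varnothing$ for the logits of the pass run with `null_cond_prob = 1.` (text conditioning dropped), and $s$ for `cond_scale`. The method then returns:

$$
\ell = \ell_\varnothing + s\,(\ell_c - \ell_\varnothing) = s\,\ell_c + (1 - s)\,\ell_\varnothing
$$

At $s = 1$ this reduces to $\ell_c$, which is why the method skips the second forward pass in that case. $s = 0$ gives the unconditional logits, and $s > 1$ pushes the logits further along the direction from $\ell_\varnothing$ to $\ell_c$.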

def forward_with_cond_scale(self, *args, cond_scale = 1, cache = None, **kwargs):
    if cond_scale == 1:
        return self(*args, **kwargs)

    prev_cache = cache.copy() if exists(cache) else None
    logits = self(*args, cache = cache, **kwargs)

    # discovery by Katherine Crowson
    # https://twitter.com/RiversHaveWings/status/1478093658716966912
    null_cond_logits = self(*args, null_cond_prob = 1., cache = prev_cache, **kwargs)
    return null_cond_logits + (logits - null_cond_logits) * cond_scale
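To show how the final combination step changes the sampling distribution, the sketch below applies the same arithmetic to plain Python lists of toy logits. The real method operates on the PyTorch tensors returned by `forward`. The helper names `guided_logits` and `softmax` and all numeric values are illustrative assumptions, not part of DALLE-pytorch.

```python
import math
from typing import List


def guided_logits(cond: List[float], null: List[float], cond_scale: float) -> List[float]:
    """Element-wise null + (cond - null) * cond_scale, mirroring the method's return line."""
    return [n + (c - n) * cond_scale for c, n in zip(cond, null)]


def softmax(logits: List[float]) -> List[float]:
    # subtract the max for numerical stability before exponentiating
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


cond = [2.0, 1.0, 0.0]  # toy logits from the text-conditioned pass
null = [1.0, 1.0, 1.0]  # toy logits from the pass with conditioning dropped

for s in (0.0, 1.0, 3.0):
    probs = softmax(guided_logits(cond, null, s))
    print(s, [round(p, 3) for p in probs])
```

With these toy values, `cond_scale = 0` yields the uniform distribution implied by the flat unconditional logits. The probability of the most likely token rises as `cond_scale` grows past 1.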

Callers (1)

generate_images · method · 0.95

Calls (1)

exists · function · 0.70

Tested by

No test coverage detected.
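Since no tests cover this method, here is a hypothetical unit-style check that exercises its control flow against a stand-in model. `ToyDALLE`, its fixed logits, and the `offset` cache key are invented for illustration. The `exists` helper is assumed to be a plain `is not None` check. List arithmetic replaces the tensor arithmetic of the real model, and `__call__` stands in for the `nn.Module` call that dispatches to `forward`.

```python
def exists(val):
    # assumed equivalent of the helper the method calls: "is this value set?"
    return val is not None


class ToyDALLE:
    """Hypothetical stand-in for DALLE: fixed logits, records each forward call."""

    def __init__(self):
        # one entry per forward call: (null_cond_prob, snapshot of cache or None)
        self.calls = []

    def __call__(self, *args, null_cond_prob=0., cache=None, **kwargs):
        self.calls.append((null_cond_prob, None if cache is None else dict(cache)))
        if cache is not None:
            # mimic a decoding step that writes into the cache in place
            cache["offset"] = cache.get("offset", 0) + 1
        # toy logits: conditioning dropped -> flat, conditioned -> peaked
        return [1.0, 1.0] if null_cond_prob == 1. else [3.0, 0.0]

    def forward_with_cond_scale(self, *args, cond_scale=1, cache=None, **kwargs):
        # same control flow as the method above, with list arithmetic instead of tensors
        if cond_scale == 1:
            return self(*args, **kwargs)

        prev_cache = cache.copy() if exists(cache) else None
        logits = self(*args, cache=cache, **kwargs)
        null_cond_logits = self(*args, null_cond_prob=1., cache=prev_cache, **kwargs)
        return [n + (c - n) * cond_scale for c, n in zip(logits, null_cond_logits)]


# cond_scale == 1: a single forward pass, conditional logits returned unchanged
m1 = ToyDALLE()
assert m1.forward_with_cond_scale(cond_scale=1, cache={"offset": 5}) == [3.0, 0.0]
assert len(m1.calls) == 1

# cond_scale == 2: two passes that both start from the same cache state
m2 = ToyDALLE()
cache = {"offset": 5}
out = m2.forward_with_cond_scale(cond_scale=2, cache=cache)
assert out == [5.0, -1.0]
assert [c[0] for c in m2.calls] == [0., 1.]
assert m2.calls[0][1] == m2.calls[1][1] == {"offset": 5}
assert cache["offset"] == 6  # only the conditional pass wrote to the caller's cache
```

The first check also exposes a detail of the code as written: on the `cond_scale == 1` path, the method calls `self(*args, **kwargs)` without forwarding `cache`, so the recorded cache for that call is `None`.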