Method `cond_fn(self, latents, timestep, index, text_embeddings, noise_pred_original, guide_embeddings_clip, clip_guidance_scale, num_cutouts, use_cutouts=True)`
Source from the content-addressed store, hash-verified
| 1584 | |
def cond_fn(
    self,
    latents,
    timestep,
    index,
    text_embeddings,
    noise_pred_original,
    guide_embeddings_clip,
    clip_guidance_scale,
    num_cutouts,
    use_cutouts=True,
):
    """Apply ``self.cond_fn1`` to a batch by splitting it into single samples.

    A batch of length 1 is forwarded to ``self.cond_fn1`` unchanged, and that
    call's return value is returned as-is.

    For larger batches, ``latents``, ``text_embeddings``,
    ``noise_pred_original`` and ``guide_embeddings_clip`` are indexed
    sample by sample. Each slice gets a leading dimension of size 1 via
    ``unsqueeze(0)``, and ``self.cond_fn1`` is called once per sample.
    ``timestep``, ``index``, ``clip_guidance_scale``, ``num_cutouts`` and
    ``use_cutouts`` are passed through unchanged on every call.

    Each per-sample call must return a pair. The first elements of all pairs
    are concatenated with ``torch.cat``, then the second elements, and the
    two results are returned as a tuple (the original named them
    ``noise_pred`` and ``cond_latents``).

    NOTE(review): assumes the four batched inputs share the same leading
    length as ``latents`` -- a shorter one raises ``IndexError`` -- TODO
    confirm against callers.
    """
    batch_size = len(latents)

    # Fast path: nothing to split, delegate directly.
    if batch_size == 1:
        return self.cond_fn1(
            latents,
            timestep,
            index,
            text_embeddings,
            noise_pred_original,
            guide_embeddings_clip,
            clip_guidance_scale,
            num_cutouts,
            use_cutouts,
        )

    def _take(batch, k):
        # Select sample k but keep a leading batch axis of size 1.
        return batch[k].unsqueeze(0)

    noise_chunks = []
    latent_chunks = []
    for k in range(batch_size):
        chunk_noise, chunk_latents = self.cond_fn1(
            _take(latents, k),
            timestep,
            index,
            _take(text_embeddings, k),
            _take(noise_pred_original, k),
            _take(guide_embeddings_clip, k),
            clip_guidance_scale,
            num_cutouts,
            use_cutouts,
        )
        noise_chunks.append(chunk_noise)
        latent_chunks.append(chunk_latents)

    merged_noise = torch.cat(noise_chunks)
    merged_latents = torch.cat(latent_chunks)
    return merged_noise, merged_latents
| 1625 | |
| 1626 | @torch.enable_grad() |
| 1627 | def cond_fn1( |
Tested by
no test coverage detected