Continuous Normalizing Flow model
| 29 | return v.view(x.shape[0], 3, 32, 32) # Reshape back to image |
| 30 | |
class CNF(nn.Module):
    """Continuous Normalizing Flow model driven by a learned velocity field.

    Wraps a velocity network called as ``velocity_net(x, t)`` and exposes:

    * :meth:`velocity_field` - the raw velocity v_θ(t, x_t);
    * :meth:`f_theta` - the extrapolation f_θ(t, x_t) = x_t + (1 - t) v_θ(t, x_t);
    * :meth:`sample` - Euler integration of the velocity field from t=0 to t=1;
    * :meth:`forward` - the terms of a velocity-consistency loss between t and
      t + dt on the straight line joining x_0 and x_1.

    Args:
        velocity_net: Callable ``velocity_net(x, t)``; its output is added to /
            scaled against ``x``, so it is expected to return a tensor shaped
            like ``x``.
        sigma_min: Stored on the instance; not read by any method here.
        n_steps: Number of Euler steps used by :meth:`sample`.
        data_shape: Per-sample shape of the noise drawn by :meth:`sample`
            (keyword-only). Defaults to ``(3, 32, 32)``.
    """

    def __init__(self, velocity_net, sigma_min=0.1, n_steps=100, *,
                 data_shape=(3, 32, 32)):
        super().__init__()
        self.velocity_net = velocity_net
        # NOTE(review): sigma_min is stored but never used in this class —
        # confirm whether forward()'s interpolation was meant to include it.
        self.sigma_min = sigma_min
        self.n_steps = n_steps
        self.data_shape = tuple(data_shape)

    @staticmethod
    def _broadcast_t(t, x):
        """Reshape ``t`` to (batch, 1, ..., 1) so it broadcasts against ``x``."""
        return t.reshape(-1, *([1] * (x.dim() - 1)))

    def _f_from_velocity(self, t, x_t, v):
        """Compute x_t + (1 - t) * v for an already-evaluated velocity ``v``."""
        return x_t + (1 - self._broadcast_t(t, x_t)) * v

    def f_theta(self, t, x_t):
        """Compute f_θ(t, x_t) = x_t + (1 - t) * v_θ(t, x_t).

        ``t`` holds one time per sample (or a single shared time) and is
        broadcast over the remaining dimensions of ``x_t``.
        """
        return self._f_from_velocity(t, x_t, self.velocity_net(x_t, t))

    def velocity_field(self, t, x_t):
        """Return v_θ(t, x_t); note the network takes its arguments as (x, t)."""
        return self.velocity_net(x_t, t)

    @torch.no_grad()
    def sample(self, n_samples, device):
        """Generate samples by Euler-integrating the velocity field over [0, 1].

        Runs under ``torch.no_grad()``: otherwise all ``n_steps`` network
        evaluations would be retained in the autograd graph and memory would
        grow linearly with ``n_steps``.

        Args:
            n_samples: Number of samples to draw.
            device: Device for the initial noise and all intermediate tensors.

        Returns:
            Tensor of shape ``(n_samples, *data_shape)``.
        """
        # Start from standard-normal noise at t = 0, allocated on the target device.
        x = torch.randn(n_samples, *self.data_shape, device=device)

        dt = 1.0 / self.n_steps
        for step in range(self.n_steps):
            t = torch.full((n_samples,), step * dt, device=device)
            x = x + dt * self.velocity_field(t, x)

        return x

    def forward(self, x_0, x_1, t, dt=0.01):
        """Compute the terms of the velocity-consistency loss.

        Points are placed on the straight line x_s = s * x_1 + (1 - s) * x_0
        at s = t and s = t_next = min(t + dt, 1).

        Args:
            x_0: Endpoint tensor at s = 0.
            x_1: Endpoint tensor at s = 1.
            t: Times, one per sample (or a single shared time).
            dt: Offset between the two time points (default 0.01).

        Returns:
            Tuple ``(f_t, f_next, v_t, v_next)`` where ``f_*`` are f_θ and
            ``v_*`` are v_θ evaluated at (t, x_t) and (t_next, x_next).
        """
        # Clamp so the second point never leaves the [0, 1] time interval.
        t_next = torch.clamp(t + dt, max=1.0)

        t_b = self._broadcast_t(t, x_0)
        t_next_b = self._broadcast_t(t_next, x_0)
        x_t = t_b * x_1 + (1 - t_b) * x_0
        x_next = t_next_b * x_1 + (1 - t_next_b) * x_0

        # One network evaluation per time point; f_θ is derived from the same
        # velocity, which halves the cost and keeps f_θ and v_θ consistent
        # when the network is stochastic (e.g. dropout).
        v_t = self.velocity_field(t, x_t)
        v_next = self.velocity_field(t_next, x_next)
        f_t = self._f_from_velocity(t, x_t, v_t)
        f_next = self._f_from_velocity(t_next, x_next, v_next)

        return f_t, f_next, v_t, v_next
no outgoing calls