MCPcopy
hub / github.com/HKUDS/AI-Researcher / CNF

Class CNF

examples/con_flowmatching/project/model/network.py:31–82  ·  view source on GitHub ↗

Continuous Normalizing Flow model

Source from the content-addressed store, hash-verified

29 return v.view(x.shape[0], 3, 32, 32) # Reshape back to image
30
class CNF(nn.Module):
    """Continuous Normalizing Flow model.

    Wraps a velocity network and provides Euler sampling plus the quantities
    needed for a velocity-consistency loss.

    Args:
        velocity_net: Callable invoked as ``velocity_net(x_t, t)`` that
            returns the velocity for state ``x_t`` at time ``t``.
        sigma_min: Stored on the instance; no method of this class reads it.
        n_steps: Number of Euler steps performed by :meth:`sample`.
    """

    def __init__(self, velocity_net, sigma_min=0.1, n_steps=100):
        super().__init__()
        self.velocity_net = velocity_net
        self.sigma_min = sigma_min
        self.n_steps = n_steps

    @staticmethod
    def _broadcast_time(t, x):
        """Reshape ``t`` to ``(-1, 1, ..., 1)`` so it broadcasts against ``x``."""
        return t.reshape(-1, *([1] * (x.dim() - 1)))

    def f_theta(self, t, x_t, v=None):
        """Compute f_θ(t, x_t) = x_t + (1-t)v_θ(t, x_t).

        Args:
            t: Time tensor with one entry per batch element.
            x_t: State tensor whose leading dimension is the batch.
            v: Optional precomputed ``v_θ(t, x_t)``. When omitted the velocity
                network is evaluated here.

        Returns:
            Tensor ``x_t + (1 - t) * v`` with ``t`` broadcast over the
            non-batch dimensions of ``x_t``.
        """
        if v is None:
            v = self.velocity_net(x_t, t)
        # Broadcast over whatever trailing dims x_t has instead of assuming
        # a fixed (3, 32, 32) image layout.
        return x_t + (1 - self._broadcast_time(t, x_t)) * v

    def velocity_field(self, t, x_t):
        """Compute velocity field v_θ(t, x_t)."""
        return self.velocity_net(x_t, t)

    def sample(self, n_samples, device, sample_shape=(3, 32, 32)):
        """Generate samples using the Euler method.

        Integrates ``dx/dt = v_θ(t, x)`` from ``t = 0`` with ``self.n_steps``
        steps of size ``1 / self.n_steps``, starting from standard normal noise.
        Gradient tracking is not disabled here; wrap the call in
        ``torch.no_grad()`` when sampling purely for inference.

        Args:
            n_samples: Number of samples to draw.
            device: Device the samples are moved to and integrated on.
            sample_shape: Per-sample shape; defaults to ``(3, 32, 32)``.

        Returns:
            Tensor of shape ``(n_samples, *sample_shape)``.
        """
        # Sample from base distribution
        x = torch.randn(n_samples, *sample_shape).to(device)

        dt = 1.0 / self.n_steps
        for i in range(self.n_steps):
            t = torch.full((n_samples,), i * dt, device=device)
            x = x + dt * self.velocity_field(t, x)

        return x

    def forward(self, x_0, x_1, t):
        """Forward pass computing the terms of the velocity consistency loss.

        Builds the linear interpolants ``x_t = t*x_1 + (1-t)*x_0`` at ``t`` and
        at ``t_next = min(t + 0.01, 1.0)`` and evaluates the model at both.

        Args:
            x_0: Batch of start points.
            x_1: Batch of end points, same shape as ``x_0``.
            t: Time tensor with one entry per batch element.

        Returns:
            Tuple ``(f_t, f_next, v_t, v_next)``.
        """
        dt = 0.01  # Small time step for consistency loss
        t_next = torch.clamp(t + dt, max=1.0)

        # Current and next points on the straight path from x_0 to x_1
        t_exp = self._broadcast_time(t, x_0)
        t_next_exp = self._broadcast_time(t_next, x_0)
        x_t = t_exp * x_1 + (1 - t_exp) * x_0
        x_next = t_next_exp * x_1 + (1 - t_next_exp) * x_0

        # Evaluate the network once per point and reuse the velocity for
        # f_theta, rather than running the network twice on identical inputs.
        v_t = self.velocity_field(t, x_t)
        v_next = self.velocity_field(t_next, x_next)
        f_t = self.f_theta(t, x_t, v=v_t)
        f_next = self.f_theta(t_next, x_next, v=v_next)

        return f_t, f_next, v_t, v_next

Callers 6

main · Function · 0.90
main · Function · 0.90
main · Function · 0.90
run_experiment · Function · 0.90
run_ablation_experiment · Function · 0.90

Calls

no outgoing calls

Tested by 2

main · Function · 0.72
main · Function · 0.72