()
| 232 | |
| 233 | |
def test01():
    """Check that a 1x1 Conv2d and a Linear layer with the same weights agree.

    Builds a ``Conv2d(3, 32, kernel_size=1)`` and a ``Linear(3, 32)`` and
    prints their parameter counts. It then copies the conv weight/bias into
    the linear layer and runs both on the same random batch:

    - the conv on the ``(b, c, h, w)`` input directly;
    - the linear layer on the input rearranged to ``(b, h*w, c)``.

    The conv output is rearranged to the same ``(b, h*w, c)`` layout and the
    two results are compared element-wise. Runs on CUDA when available,
    otherwise on CPU.

    Raises:
        AssertionError: if the two outputs differ beyond tolerance.
    """
    # conv1x1 vs linear
    from sgm.util import count_params

    # Fall back to CPU so the comparison can also run on machines without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    conv = torch.nn.Conv2d(3, 32, kernel_size=1).to(device)
    print(count_params(conv))
    linear = torch.nn.Linear(3, 32).to(device)
    print(count_params(linear))

    print(conv.weight.shape)

    # Use the same initialization. The conv weight is (out, in, 1, 1);
    # flattening the trailing 1x1 kernel dims gives the (out, in) matrix that
    # Linear expects. Copy in place so the two modules do not alias each
    # other's parameter storage.
    with torch.no_grad():
        linear.weight.copy_(conv.weight.flatten(start_dim=1))
        linear.bias.copy_(conv.bias)

    print(linear.weight.shape)

    x = torch.randn(11, 3, 64, 64).to(device)

    with torch.no_grad():
        xr = einops.rearrange(x, "b c h w -> b (h w) c").contiguous()
        print(xr.shape)
        out_linear = linear(xr)
        print(out_linear.mean(), out_linear.shape)

        out_conv = conv(x)
        print(out_conv.mean(), out_conv.shape)

    # Matching means do not prove equivalence, so compare element-wise in the
    # same (b, h*w, c) layout. The tolerance is loose because cuDNN
    # convolutions may run in TF32 on recent GPUs, which adds small numeric
    # differences. A layout or weight mismatch would produce errors far larger
    # than this tolerance.
    out_conv_tokens = einops.rearrange(out_conv, "b c h w -> b (h w) c")
    max_diff = (out_linear - out_conv_tokens).abs().max()
    print("max abs diff:", max_diff.item())
    assert torch.allclose(out_linear, out_conv_tokens, atol=1e-2, rtol=1e-2), max_diff
    print("done with test01.\n")
| 261 | |
| 262 | |
| 263 | def test02(): |
nothing calls this directly
no test coverage detected