MCPcopy
hub / github.com/Stability-AI/generative-models / test01

Function test01

scripts/tests/attention.py:234–260  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

232
233
def test01(device="cuda"):
    """Check that a 1x1 ``Conv2d`` matches a ``Linear`` layer with shared weights.

    Builds a 1x1 convolution and a linear layer with the same in/out channel
    counts and prints their parameter counts. It then copies the conv's
    weight and bias into the linear layer and runs both on the same random
    input; the linear layer sees the input flattened to ``b (h w) c``.
    Finally it compares the two outputs element-wise after reshaping the
    linear output back to ``b c h w``.

    Args:
        device: Device to run on. Defaults to ``"cuda"``, matching the
            previously hard-coded ``.cuda()`` calls; pass ``"cpu"`` to run
            without a GPU.

    Raises:
        AssertionError: If the conv and linear outputs differ beyond the
            tolerance used below.
    """
    # conv1x1 vs linear
    from sgm.util import count_params

    conv = torch.nn.Conv2d(3, 32, kernel_size=1).to(device)
    print(count_params(conv))
    linear = torch.nn.Linear(3, 32).to(device)
    print(count_params(linear))

    print(conv.weight.shape)

    # Use the same initialization. Conv2d weights are (out, in, kH, kW), so
    # squeezing the two trailing 1x1 dims yields the (out, in) layout Linear
    # expects. The new Parameters share storage with the conv's tensors.
    linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
    linear.bias = torch.nn.Parameter(conv.bias)

    print(linear.weight.shape)

    # Sample on CPU and then move, like the original `.cuda()` call did.
    x = torch.randn(11, 3, 64, 64).to(device)
    height, width = x.shape[-2:]

    # Move channels last so Linear acts on each pixel's channel vector.
    xr = einops.rearrange(x, "b c h w -> b (h w) c").contiguous()
    print(xr.shape)
    out_linear = linear(xr)
    print(out_linear.mean(), out_linear.shape)

    out_conv = conv(x)
    print(out_conv.mean(), out_conv.shape)

    # Actually verify the equivalence instead of only eyeballing the means.
    # The tolerance is loose because cuDNN convolutions may run in TF32 on
    # recent GPUs while the Linear matmul may not; a layout bug would still
    # produce errors far above this threshold.
    out_linear_nchw = einops.rearrange(
        out_linear, "b (h w) c -> b c h w", h=height, w=width
    )
    torch.testing.assert_close(out_linear_nchw, out_conv, rtol=1e-2, atol=1e-2)
    print("done with test01.\n")
261
262
263def test02():

Callers

nothing calls this directly

Calls 2

count_params — Function · 0.90
linear — Function · 0.85

Tested by

no test coverage detected