Repository: github.com/hpcaitech/Open-Sora

Function test_lr_scheduler

tests/test_lr_scheduler.py, lines 9–27
Signature: test_lr_scheduler(), takes no arguments.


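```python
import torch
from torch.optim import Adam
from torchvision.models import resnet50
from tqdm import tqdm

# LinearWarmupLR is defined in the Open-Sora package (module path assumed).
from opensora.utils.lr_scheduler import LinearWarmupLR


def test_lr_scheduler():
    warmup_steps = 200
    # Requires a CUDA device: the model and the input are moved to the GPU.
    model = resnet50().cuda()
    optimizer = Adam(model.parameters(), lr=0.01)
    scheduler = LinearWarmupLR(optimizer, warmup_steps=warmup_steps)
    current_lr = scheduler.get_lr()[0]
    data = torch.rand(1, 3, 224, 224).cuda()

    # Run twice the warmup length so both phases are exercised.
    for i in tqdm(range(warmup_steps * 2)):
        optimizer.zero_grad()  # clear gradients from the previous step
        out = model(data)
        out.mean().backward()
        optimizer.step()
        scheduler.step()

        if i >= warmup_steps:
            # After warmup the LR must sit exactly at the base value.
            assert scheduler.get_lr()[0] == 0.01
        else:
            # During warmup the LR must strictly increase on every step.
            assert scheduler.get_lr()[0] > current_lr, f"{scheduler.get_lr()[0]} <= {current_lr}"
            current_lr = scheduler.get_lr()[0]


if __name__ == "__main__":
    # The excerpt stops at this guard; calling the test here is assumed.
    test_lr_scheduler()
```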
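The test needs a GPU and a full resnet50 only to drive the optimizer; the property it checks is purely about the learning-rate schedule. The sketch below reproduces the same two assertions on CPU with a single scalar parameter. It is not Open-Sora's code: LinearWarmupLR is replaced by PyTorch's built-in `LambdaLR` with an assumed warmup factor of `(step + 1) / (warmup_steps + 1)`, and the helper names `make_linear_warmup` and `check_warmup` are hypothetical. It reads the LR through `get_last_lr()`, the accessor PyTorch documents for this purpose, because built-in schedulers such as `LambdaLR` warn when `get_lr()` is called outside `step()`.

```python
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR


def make_linear_warmup(optimizer, warmup_steps):
    """Hypothetical stand-in for LinearWarmupLR (assumed formula):
    scale the base LR by (step + 1) / (warmup_steps + 1) during warmup,
    then hold it at the base LR."""
    def factor(step):
        if step < warmup_steps:
            return (step + 1) / (warmup_steps + 1)
        return 1.0
    return LambdaLR(optimizer, lr_lambda=factor)


def check_warmup(warmup_steps=200, base_lr=0.01):
    # A single scalar parameter replaces resnet50 so the check runs on CPU.
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = Adam([param], lr=base_lr)
    scheduler = make_linear_warmup(optimizer, warmup_steps)
    current_lr = scheduler.get_last_lr()[0]
    history = [current_lr]  # LR after construction, then after each step

    for i in range(warmup_steps * 2):
        optimizer.zero_grad()
        loss = (param ** 2).sum()
        loss.backward()
        optimizer.step()
        scheduler.step()
        lr = scheduler.get_last_lr()[0]
        history.append(lr)
        if i >= warmup_steps:
            # Same check as the original test: LR stays at the base value.
            assert lr == base_lr
        else:
            # Same check as the original test: LR strictly increases.
            assert lr > current_lr, f"{lr} <= {current_lr}"
            current_lr = lr
    return history
```

Calling `check_warmup()` runs all 400 steps and returns the LR history, so the schedule can also be inspected directly rather than only asserted on.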

Callers: 1

Calls: 4

get_lr · method · 0.95
LinearWarmupLR · class · 0.90
tqdm · function · 0.85
backward · method · 0.45

Tested by: no test coverage detected