MCPcopy Index your code
hub / github.com/microsoft/Cream / MetaMatchingNetwork

Class MetaMatchingNetwork

Cream/lib/models/MetaMatchingNetwork.py:14–130  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

12
13# Meta Matching Network
14class MetaMatchingNetwork():
15 def __init__(self, cfg):
16 self.cfg = cfg
17
18 # only update student network weights
19 def update_student_weights_only(self, random_cand, grad_1, optimizer, model):
20 for weight, grad_item in zip(model.module.rand_parameters(random_cand), grad_1):
21 weight.grad = grad_item
22 torch.nn.utils.clip_grad_norm_(model.module.rand_parameters(random_cand), 1)
23 optimizer.step()
24 for weight, grad_item in zip(model.module.rand_parameters(random_cand), grad_1):
25 del weight.grad
26
27 # only update meta networks weights
28 def update_meta_weights_only(self, random_cand, teacher_cand, model, optimizer, grad_teacher):
29 for weight, grad_item in zip(model.module.rand_parameters(
30 teacher_cand, self.cfg.SUPERNET.PICK_METHOD == 'meta'), grad_teacher):
31 weight.grad = grad_item
32
33 # clip gradients
34 torch.nn.utils.clip_grad_norm_(
35 model.module.rand_parameters(
36 teacher_cand, self.cfg.SUPERNET.PICK_METHOD == 'meta'), 1)
37
38 optimizer.step()
39 for weight, grad_item in zip(model.module.rand_parameters(
40 teacher_cand, self.cfg.SUPERNET.PICK_METHOD == 'meta'), grad_teacher):
41 del weight.grad
42
43 # simulate sgd updating
44 def simulate_sgd_update(self, w, g, optimizer):
45 return -g * optimizer.param_groups[-1]['lr'] + w
46
47 # split training images into several slices
48 def get_minibatch_input(self, input):
49 slice = self.cfg.SUPERNET.SLICE
50 x = deepcopy(input[:slice].clone().detach())
51 return x
52
53 def calculate_1st_gradient(self, kd_loss, model, random_cand, optimizer):
54 optimizer.zero_grad()
55 grad = torch.autograd.grad(
56 kd_loss,
57 model.module.rand_parameters(random_cand),
58 create_graph=True)
59 return grad
60
61 def calculate_2nd_gradient(self, validation_loss, model, optimizer, random_cand, teacher_cand, students_weight):
62 optimizer.zero_grad()
63 grad_student_val = torch.autograd.grad(
64 validation_loss, model.module.rand_parameters(random_cand), retain_graph=True)
65
66 grad_teacher = torch.autograd.grad(
67 students_weight[0],
68 model.module.rand_parameters(
69 teacher_cand,
70 self.cfg.SUPERNET.PICK_METHOD == 'meta'),
71 grad_outputs=grad_student_val)

Callers 1

mainFunction · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected