| 12 | |
| 13 | # Meta Matching Network |
| 14 | class MetaMatchingNetwork(): |
def __init__(self, cfg):
    """Keep a reference to the experiment configuration.

    Args:
        cfg: configuration object. The methods of this class read
            ``cfg.SUPERNET.PICK_METHOD`` and ``cfg.SUPERNET.SLICE`` from it.
    """
    self.cfg = cfg
| 17 | |
# Update only the student network weights, using precomputed gradients.
def update_student_weights_only(self, random_cand, grad_1, optimizer, model, max_norm=1):
    """Take one optimizer step on the student parameters with given gradients.

    The gradients in ``grad_1`` are assigned positionally to the parameters
    yielded by ``model.module.rand_parameters(random_cand)`` (``zip``
    semantics: extra items on either side are ignored). The total gradient
    norm of those parameters is clipped to ``max_norm`` and
    ``optimizer.step()`` is called. The ``.grad`` fields that were assigned
    are then reset to ``None``, even if clipping or stepping raises.

    Args:
        random_cand: candidate passed to ``model.module.rand_parameters``.
        grad_1: gradient tensors, one per parameter, in the order yielded
            by ``rand_parameters(random_cand)``. Any iterable is accepted;
            it is consumed exactly once.
        optimizer: optimizer whose ``step()`` applies the update.
        model: object exposing ``module.rand_parameters``.
        max_norm: maximum total gradient norm handed to
            ``torch.nn.utils.clip_grad_norm_`` (defaults to the previous
            hard-coded value of 1).
    """
    # Materialize once so that assigning, clipping and clearing all see the
    # same parameters and gradients, even if either side is a one-shot iterator.
    params = list(model.module.rand_parameters(random_cand))
    grads = list(grad_1)

    assigned = []
    try:
        for weight, grad_item in zip(params, grads):
            weight.grad = grad_item
            assigned.append(weight)
        torch.nn.utils.clip_grad_norm_(params, max_norm)
        optimizer.step()
    finally:
        # Drop the borrowed gradients unconditionally so they neither leak
        # into a later step nor keep any attached autograd graph alive.
        for weight in assigned:
            weight.grad = None
| 26 | |
# Update only the meta-network weights, using precomputed gradients.
def update_meta_weights_only(self, random_cand, teacher_cand, model, optimizer, grad_teacher, max_norm=1):
    """Take one optimizer step on the parameters selected for the meta update.

    The parameters come from ``model.module.rand_parameters(teacher_cand,
    use_meta)``, where ``use_meta`` is ``cfg.SUPERNET.PICK_METHOD == 'meta'``.
    Which tensors that yields is decided by ``rand_parameters``
    (presumably the meta-network weights when ``use_meta`` is true; confirm
    there). The gradients in ``grad_teacher`` are assigned positionally
    (``zip`` semantics). The total gradient norm is clipped to ``max_norm``
    and ``optimizer.step()`` is called. The assigned ``.grad`` fields are
    then reset to ``None``, even if clipping or stepping raises.

    Args:
        random_cand: unused by this method; kept for interface compatibility.
        teacher_cand: candidate passed to ``model.module.rand_parameters``.
        model: object exposing ``module.rand_parameters``.
        optimizer: optimizer whose ``step()`` applies the update.
        grad_teacher: gradient tensors, one per selected parameter. Any
            iterable is accepted; it is consumed exactly once.
        max_norm: maximum total gradient norm handed to
            ``torch.nn.utils.clip_grad_norm_`` (defaults to the previous
            hard-coded value of 1).
    """
    use_meta = self.cfg.SUPERNET.PICK_METHOD == 'meta'
    # Materialize once so that assigning, clipping and clearing all see the
    # same parameters and gradients, even if either side is a one-shot iterator.
    params = list(model.module.rand_parameters(teacher_cand, use_meta))
    grads = list(grad_teacher)

    assigned = []
    try:
        for weight, grad_item in zip(params, grads):
            weight.grad = grad_item
            assigned.append(weight)
        # clip gradients
        torch.nn.utils.clip_grad_norm_(params, max_norm)
        optimizer.step()
    finally:
        # Drop the borrowed gradients unconditionally so they neither leak
        # into a later step nor keep any attached autograd graph alive.
        for weight in assigned:
            weight.grad = None
| 42 | |
# Out-of-place plain-SGD step, used to simulate an update without touching w.
def simulate_sgd_update(self, w, g, optimizer):
    """Return ``w`` after one plain SGD step along ``-g``.

    The learning rate is taken from the last entry of
    ``optimizer.param_groups``. The computation uses only out-of-place
    operators, so ``w`` and ``g`` are not modified.

    Args:
        w: current weight value(s).
        g: gradient of the same shape as ``w``.
        optimizer: object exposing ``param_groups`` (a list of dicts with an
            ``'lr'`` key).

    Returns:
        ``w - lr * g``.
    """
    learning_rate = optimizer.param_groups[-1]['lr']
    step = g * learning_rate
    return w - step
| 46 | |
# Take the first cfg.SUPERNET.SLICE training samples as a detached mini-batch.
def get_minibatch_input(self, input):
    """Return an independent, detached copy of the first ``SLICE`` samples.

    Args:
        input: batched tensor; the slice ``input[:cfg.SUPERNET.SLICE]`` is
            taken along dimension 0. (The parameter name shadows the builtin
            but is kept so keyword callers keep working.)

    Returns:
        A new tensor equal to ``input[:SLICE]`` that does not share storage
        with ``input`` and does not require grad.
    """
    num_samples = self.cfg.SUPERNET.SLICE
    # clone() already allocates fresh storage and detach() cuts the autograd
    # history, so no additional deepcopy is needed on top of them.
    return input[:num_samples].clone().detach()
| 52 | |
def calculate_1st_gradient(self, kd_loss, model, random_cand, optimizer):
    """Differentiate ``kd_loss`` w.r.t. the parameters of ``random_cand``.

    ``optimizer.zero_grad()`` is called first. The gradients are then
    returned rather than accumulated into ``.grad``.

    Args:
        kd_loss: scalar loss tensor to differentiate.
        model: object exposing ``module.rand_parameters``.
        random_cand: candidate passed to ``model.module.rand_parameters``.
        optimizer: optimizer whose gradients are reset before differentiating.

    Returns:
        Tuple of gradient tensors, one per parameter yielded by
        ``rand_parameters(random_cand)``.
    """
    optimizer.zero_grad()
    student_params = model.module.rand_parameters(random_cand)
    # create_graph=True builds the graph of the derivative itself, so the
    # returned gradients can be differentiated again (higher-order terms).
    return torch.autograd.grad(kd_loss, student_params, create_graph=True)
| 60 | |
| 61 | def calculate_2nd_gradient(self, validation_loss, model, optimizer, random_cand, teacher_cand, students_weight): |
| 62 | optimizer.zero_grad() |
| 63 | grad_student_val = torch.autograd.grad( |
| 64 | validation_loss, model.module.rand_parameters(random_cand), retain_graph=True) |
| 65 | |
| 66 | grad_teacher = torch.autograd.grad( |
| 67 | students_weight[0], |
| 68 | model.module.rand_parameters( |
| 69 | teacher_cand, |
| 70 | self.cfg.SUPERNET.PICK_METHOD == 'meta'), |
| 71 | grad_outputs=grad_student_val) |