(self, xs, ys, **kwargs)
| 95 | return loss, focal_loss, pull_loss, push_loss, regr_loss |
| 96 | |
def validate(self, xs, ys, **kwargs):
    """Return the mean validation loss for one batch, without tracking gradients.

    Args:
        xs: Sequence of network inputs; each element is moved to the GPU via
            ``.cuda(non_blocking=True)`` before the forward pass.
        ys: Sequence of targets, moved to the GPU the same way.
        **kwargs: Accepted for call-site compatibility; not used.

    Returns:
        Element 0 of ``self.network(xs, ys)`` reduced with ``.mean()``.
        Because the computation runs under ``torch.no_grad()``, the result
        carries no autograd graph.
    """
    with torch.no_grad():
        xs = [x.cuda(non_blocking=True) for x in xs]
        ys = [y.cuda(non_blocking=True) for y in ys]

        loss_kp = self.network(xs, ys)
        # Only index 0 (presumably the aggregate loss; the other indices look
        # like per-component terms) is needed for validation. Reading just
        # that element avoids requiring the network to return any fixed
        # number of terms.
        return loss_kp[0].mean()
| 110 | |
| 111 | def test(self, xs, **kwargs): |
| 112 | with torch.no_grad(): |