hub / github.com/Duankaiwen/CenterNet / validate

Method validate

nnet/py_factory.py:97–109
(self, xs, ys, **kwargs)

Source from the content-addressed store, hash-verified

95          return loss, focal_loss, pull_loss, push_loss, regr_loss
96
97      def validate(self, xs, ys, **kwargs):
98          with torch.no_grad():
99              xs = [x.cuda(non_blocking=True) for x in xs]
100             ys = [y.cuda(non_blocking=True) for y in ys]
101
102             loss_kp = self.network(xs, ys)
103             loss = loss_kp[0]
104             focal_loss = loss_kp[1]
105             pull_loss = loss_kp[2]
106             push_loss = loss_kp[3]
107             regr_loss = loss_kp[4]
108             loss = loss.mean()
109             return loss
110
111     def test(self, xs, **kwargs):
112         with torch.no_grad():
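
The pattern in validate (move inputs to the device, run the forward pass without gradient tracking, reduce the total loss to a scalar) can be tried in isolation with the sketch below. It is a minimal illustration under stated assumptions, not the repository's code: ToyLossNetwork is a hypothetical stand-in for self.network that returns a 5-tuple in the same order (loss, focal, pull, push, regr), and the device argument with its CPU fallback is an addition so the example also runs without a GPU.

```python
import torch
import torch.nn as nn


class ToyLossNetwork(nn.Module):
    """Hypothetical stand-in for self.network; returns five per-sample losses."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, xs, ys):
        pred = self.linear(xs[0])
        # Squared error plays the role of the focal term in this toy model.
        focal = ((pred - ys[0]) ** 2).mean(dim=1)
        # The pull, push and regression terms are placeholders fixed at zero.
        zero = torch.zeros_like(focal)
        total = focal + zero + zero + zero
        return total, focal, zero, zero, zero


def validate(network, xs, ys, device=None):
    """Forward pass without gradient tracking; returns the mean total loss."""
    if device is None:
        # Assumption: fall back to CPU when CUDA is unavailable.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    network = network.to(device)
    with torch.no_grad():
        xs = [x.to(device, non_blocking=True) for x in xs]
        ys = [y.to(device, non_blocking=True) for y in ys]
        loss, focal_loss, pull_loss, push_loss, regr_loss = network(xs, ys)
        # Only the total loss is reduced and returned, as in the listing.
        return loss.mean()


torch.manual_seed(0)
net = ToyLossNetwork()
val_loss = validate(net, [torch.randn(8, 4)], [torch.randn(8, 1)])
print(val_loss.item())
```

As in the listing, the four component losses are unpacked but not returned, so a caller that wants them logged separately would need a different return value. Because the forward pass runs under torch.no_grad(), the returned tensor carries no autograd history and cannot be used for backpropagation.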

Callers 1

Function train · 0.95

Calls 2

Method cuda · 0.80
Method mean · 0.80

Tested by

no test coverage detected