(model, inputs)
# also encapsulate these steps
# Note: inputs is a torch tensor
def predict(model, inputs):
    """Run ``model`` on ``inputs`` and return the per-row argmax as a NumPy array.

    Args:
        model: A callable PyTorch module (e.g. ``torch.nn.Module``) that maps
            ``inputs`` to logits.
        inputs: A ``torch.Tensor`` batch to feed to the model.

    Returns:
        ``numpy.ndarray`` holding the index of the maximum logit along axis 1.
        This is presumably the class axis of ``(batch, num_classes)`` logits;
        confirm against the model's output shape.

    Note:
        This does not switch the model to eval mode. Callers that need
        inference behavior for dropout or batch-norm layers should call
        ``model.eval()`` first.
    """
    import torch  # local import keeps this block self-contained

    # Inference only: no_grad stops autograd from recording a graph through the
    # model's parameters. Marking just the input as requires_grad=False did not
    # prevent that.
    with torch.no_grad():
        # Call the module itself rather than .forward() so registered hooks run.
        logits = model(inputs)
    # Calling .numpy() on a GPU tensor raises, so move the result to the CPU first.
    # NumPy's argmax keeps the original tie-breaking (first maximum wins).
    return logits.detach().cpu().numpy().argmax(axis=1)
no test coverage detected