```python
import torch as th

# compute_acc() is defined elsewhere in the same source file (not shown in this excerpt).


def evaluate(model, g, inputs, labels, val_nid, test_nid, batch_size, device):
    """
    Evaluate the model on the validation and test sets.

    Parameters
    ----------
    model : DistSAGE
        The model to be evaluated.
    g : DistGraph
        The entire graph.
    inputs : DistTensor
        The feature data of all the nodes.
    labels : DistTensor
        The labels of all the nodes.
    val_nid : torch.Tensor
        The node IDs for validation.
    test_nid : torch.Tensor
        The node IDs for testing.
    batch_size : int
        Batch size for evaluation.
    device : torch.device
        The target device to evaluate on.

    Returns
    -------
    float
        Validation accuracy.
    float
        Test accuracy.
    """
    model.eval()  # put layers such as dropout into evaluation mode
    with th.no_grad():
        # Compute predictions for every node in the graph without tracking gradients.
        pred = model.inference(g, inputs, batch_size, device)
    model.train()  # restore training mode before returning to the training loop
    # Select the prediction rows for each split and score them against the labels.
    return compute_acc(pred[val_nid], labels[val_nid]), compute_acc(
        pred[test_nid], labels[test_nid]
    )


def run(args, device, data):
    ...  # body not included in this excerpt
```
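`evaluate` delegates the actual scoring to `compute_acc`, which lives elsewhere in the source file and is not shown here. The sketch below assumes it computes top-1 accuracy: take the argmax of each prediction row and compare it with the node's label. To stay runnable without PyTorch or DGL it uses plain Python lists in place of tensors; `argmax`, `evaluate_from_predictions`, and the toy data are hypothetical, and only the "index by node ID, then score each split" step mirrors `evaluate`.

```python
from typing import List, Sequence, Tuple


def argmax(row: Sequence[float]) -> int:
    """Index of the largest score in one row (the first index wins ties)."""
    return max(range(len(row)), key=lambda j: row[j])


def compute_acc(pred: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Hypothetical stand-in: fraction of rows whose argmax equals the label."""
    correct = sum(1 for row, y in zip(pred, labels) if argmax(row) == y)
    return correct / len(pred)


def evaluate_from_predictions(
    pred: List[List[float]],
    labels: List[int],
    val_nid: List[int],
    test_nid: List[int],
) -> Tuple[float, float]:
    """Gather rows by node ID for each split, then score each split separately."""
    val_acc = compute_acc([pred[i] for i in val_nid], [labels[i] for i in val_nid])
    test_acc = compute_acc([pred[i] for i in test_nid], [labels[i] for i in test_nid])
    return val_acc, test_acc


# Toy data: 6 nodes, 3 classes; one row of class scores per node.
pred = [
    [2.0, 0.1, 0.3],  # node 0 -> predicted class 0
    [0.2, 1.5, 0.1],  # node 1 -> predicted class 1
    [0.1, 0.2, 3.0],  # node 2 -> predicted class 2
    [1.0, 0.9, 0.0],  # node 3 -> predicted class 0
    [0.0, 0.5, 0.4],  # node 4 -> predicted class 1
    [0.3, 0.2, 0.1],  # node 5 -> predicted class 0
]
labels = [0, 1, 1, 0, 1, 0]
val_nid = [0, 1, 2]
test_nid = [3, 4, 5]

print(evaluate_from_predictions(pred, labels, val_nid, test_nid))
# prints (0.6666666666666666, 1.0): node 2 is misclassified in the validation split
```

Because predictions are computed once for all nodes and only then indexed per split, the cost of inference is shared between the validation and test scores.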