(file_name, model, use_cuda)
| 63 | |
| 64 | @staticmethod |
| 65 | def _load_checkpoint(file_name, model, use_cuda): |
| 66 | if use_cuda: |
| 67 | checkpoint = torch.load(file_name, weights_only=True) |
| 68 | else: |
| 69 | checkpoint = torch.load(file_name, weights_only=True, map_location=lambda storage, loc: storage) |
| 70 | model.load_state_dict(checkpoint["state_dict"]) |
| 71 | |
| 72 | def predict(self, texts): |
| 73 | """ |