Return the model on the given device. Raises `KeyError` if the device was not included in the most recent call to `.to()`.
(self, device: torch.device) -> Architecture
| 1431 | self._models = {device: get_on_device(device) for device in devices} |
| 1432 | |
| 1433 | def get(self, device: torch.device) -> Architecture: |
| 1434 | """Return the model on the given device. |
| 1435 | |
| 1436 | Raises: |
| 1437 | KeyError: If a device is specified that was not included in the last call to |
| 1438 | .to() |
| 1439 | """ |
| 1440 | return self._models[device] |
| 1441 | |
| 1442 | def set_dtype(self, dtype: torch.dtype) -> None: |
| 1443 | """Set the dtype of the model's parameters.""" |
no outgoing calls