Path navigation overlay on a :class:`Checkpoint`. A ``Prefix`` carries a checkpoint reference plus a fully-qualified key prefix string (no trailing dot). Path arithmetic via ``+`` / ``append`` returns new ``Prefix`` objects. Tensor reads via ``get`` / ``has`` go through the underlying checkpoint.
| 27 | |
| 28 | |
| 29 | class Prefix: |
| 30 | """Path navigation overlay on a :class:`Checkpoint`. |
| 31 | |
| 32 | A ``Prefix`` carries a checkpoint reference plus a fully-qualified |
| 33 | key prefix string (no trailing dot). Path arithmetic via ``+`` / |
| 34 | ``append`` returns new ``Prefix`` objects. Tensor reads via |
| 35 | ``get`` / ``has`` go through the underlying checkpoint. |
| 36 | """ |
| 37 | |
| 38 | __slots__ = ('ckpt', 'prefix') |
| 39 | |
| 40 | def __init__(self, ckpt: Checkpoint, prefix: str = ''): |
| 41 | self.ckpt = ckpt |
| 42 | self.prefix = prefix |
| 43 | |
| 44 | # ----- path navigation ----- |
| 45 | |
| 46 | def __add__(self, key) -> Prefix: |
| 47 | """``pfx + 'foo'`` -> Prefix at ``'parent.foo'`` (default '.' |
| 48 | separator). |
| 49 | |
| 50 | ``key`` may be ``str`` or ``int``; ints are stringified. |
| 51 | """ |
| 52 | return self.append(str(key)) |
| 53 | |
| 54 | def append(self, name: str, sep: str = '.') -> Prefix: |
| 55 | """Return a new Prefix with ``name`` appended via ``sep``. |
| 56 | |
| 57 | Empty current prefix or empty ``name`` skip the separator entirely. |
| 58 | """ |
| 59 | return Prefix(self.ckpt, self._joined(name, sep)) |
| 60 | |
| 61 | # ----- tensor access ----- |
| 62 | |
| 63 | def get(self, name: str = '', sep: str = '.', *, index=None) -> torch.Tensor: |
| 64 | """Read the tensor at ``self.prefix + sep + name``. |
| 65 | |
| 66 | Empty ``name`` reads the tensor at the exact prefix. Raises |
| 67 | ``KeyError`` on miss (delegates to checkpoint). |
| 68 | |
| 69 | If ``index`` is not None, the checkpoint slices the tensor along |
| 70 | dim 0 on CPU before transferring to GPU. |
| 71 | """ |
| 72 | return self.ckpt.get(self._joined(name, sep), index=index) |
| 73 | |
| 74 | def has(self, name: str = '', sep: str = '.') -> bool: |
| 75 | return self.ckpt.has(self._joined(name, sep)) |
| 76 | |
| 77 | def pop(self, name: str = '', sep: str = '.', *, index=None) -> torch.Tensor: |
| 78 | """Read and remove the tensor at ``self.prefix + sep + name``. |
| 79 | |
| 80 | Raises ``KeyError`` on miss. |
| 81 | |
| 82 | If ``index`` is not None, the checkpoint slices the tensor along |
| 83 | dim 0 on CPU before transferring to GPU. |
| 84 | """ |
| 85 | return self.ckpt.pop(self._joined(name, sep), index=index) |
| 86 |
no outgoing calls
no test coverage detected