(w, t=True)
| 205 | import numpy as np |
| 206 | |
def _n2p(w, t=True):
    """Turn a numpy array into a torch tensor, optionally permuting its axes.

    A 4-D array whose first three dimensions are all of size 1 is collapsed
    to 1-D (via ``flatten``) before anything else happens.  Afterwards, when
    ``t`` is truthy, arrays of rank 2, 3 or 4 are transposed using the axis
    orders in ``axis_orders`` below; arrays of any other rank (including the
    1-D result of the collapse) are passed through unchanged.

    NOTE(review): the 4-D order ``[3, 2, 0, 1]`` presumably converts a
    channels-last kernel layout (e.g. HWIO) into PyTorch's OIHW layout --
    confirm against the checkpoint format this is used with.

    Args:
        w: numpy array to convert.
        t: whether to apply the rank-dependent axis permutation.

    Returns:
        The tensor produced by ``torch.from_numpy`` on the resulting array.
    """
    # Rank -> permutation passed to ``ndarray.transpose``.
    axis_orders = {
        4: [3, 2, 0, 1],
        3: [2, 0, 1],
        2: [1, 0],
    }

    # Collapse (1, 1, 1, N)-shaped arrays down to a plain 1-D vector.
    has_unit_prefix = w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1
    arr = w.flatten() if has_unit_prefix else w

    order = axis_orders.get(arr.ndim) if t else None
    if order is not None:
        arr = arr.transpose(order)

    return torch.from_numpy(arr)
| 218 | |
| 219 | w = np.load(checkpoint_path) |
| 220 | if not prefix and 'opt/target/embedding/kernel' in w: |