| 118 | """ |
| 119 | |
def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
             kl_weight=5e-4, temp_init=1.0, use_vqinterface=True,
             remap=None, unknown_index="random"):
    """Build the projection layer, the codebook and optional index remapping.

    Args:
        num_hiddens: number of input channels of the 1x1 projection conv.
        embedding_dim: width of each codebook vector.
        n_embed: number of codebook entries; also the output channel count
            of the projection conv.
        straight_through: flag stored as ``self.straight_through`` (how it is
            consumed is not visible in this method).
        kl_weight: stored as ``self.kl_weight``.
        temp_init: initial value stored in ``self.temperature``.
        use_vqinterface: flag stored as ``self.use_vqinterface``.
        remap: optional path handed to ``np.load``. When given, the loaded
            array becomes the ``used`` buffer and its length sets
            ``self.re_embed``.
        unknown_index: ``"random"``, ``"extra"`` or an integer. Only stored
            when ``remap`` is given. ``"extra"`` reserves one additional
            index just past the used entries.
    """
    super().__init__()

    # Codebook geometry.
    self.embedding_dim = embedding_dim
    self.n_embed = n_embed

    # Hyper-parameters kept on the module for use by other methods.
    self.straight_through = straight_through
    self.temperature = temp_init
    self.kl_weight = kl_weight

    # Registration order (proj, then embed) is kept so parameters()/state_dict
    # ordering matches the original layout.
    self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
    self.embed = nn.Embedding(n_embed, embedding_dim)

    self.use_vqinterface = use_vqinterface

    self.remap = remap
    if remap is None:
        # No remapping: indices are reported over the full codebook.
        self.re_embed = n_embed
        return

    self.register_buffer("used", torch.tensor(np.load(remap)))
    n_used = self.used.shape[0]

    if unknown_index == "extra":
        # Unknown indices get their own slot appended after the used ones.
        self.unknown_index = n_used
        self.re_embed = n_used + 1
    else:
        # "random" or an explicit integer is kept as given.
        self.unknown_index = unknown_index
        self.re_embed = n_used

    print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
          f"Using {self.unknown_index} for unknown indices.")
| 150 | def remap_to_used(self, inds): |
| 151 | ishape = inds.shape |