| 303 | |
| 304 | class VectorQuantizer4(nn.Module): |
def __init__(self, n_e, e_dim, beta, legacy=False, kmeans_reset_every=1000,
             reset_thres=20):
    """Build the codebook and the bookkeeping state used for codebook resets.

    Args:
        n_e: Number of codebook entries (rows of ``self.embedding``).
        e_dim: Dimensionality of each codebook vector.
        beta: Stored as ``self.beta``. Presumably the commitment-loss weight
            used by ``forward`` -- confirm against the rest of the class.
        legacy: Stored as ``self.legacy`` for use by other methods.
        kmeans_reset_every: Stored as ``self.reset_every``. Presumably the
            number of steps between k-means codebook resets -- confirm.
        reset_thres: Stored as ``self.reset_thres``. Previously hard-coded
            to 20, which remains the default. Presumably the usage threshold
            below which a code is treated as dead during a reset -- confirm.
    """
    super().__init__()
    self.n_e = n_e
    self.e_dim = e_dim
    self.beta = beta
    self.legacy = legacy

    # Codebook of n_e vectors of size e_dim, initialised uniformly in
    # [-1/n_e, 1/n_e]. nn.init.uniform_ performs the same in-place draw as
    # `.weight.data.uniform_` but under no_grad, without reaching into `.data`.
    self.embedding = nn.Embedding(self.n_e, self.e_dim)
    nn.init.uniform_(self.embedding.weight, -1.0 / self.n_e, 1.0 / self.n_e)

    self.re_embed = n_e
    self.reset_every = kmeans_reset_every
    self.reset_thres = reset_thres
    # Plain Python list, not a registered buffer: it is not saved in
    # state_dict and is not moved by .to(device).
    self.z_buffer = []
    # Registered buffers are saved in state_dict and follow .to(device).
    # use_flag holds one value per codebook entry (presumably a usage
    # counter/flag -- confirm against the reset logic).
    self.register_buffer('use_flag', torch.zeros(n_e))
    # NOTE(review): 'steps' uses the default float dtype (float32 unless
    # changed), which represents integers exactly only up to 2**24; a
    # `steps += 1` counter stops advancing beyond that. Consider a long
    # dtype if this counter can run that long -- check all uses first.
    self.register_buffer('steps', torch.zeros(1))
| 322 | def encode(self, z): |
| 323 | B, T, _ = z.shape |