MCPcopy
hub / github.com/OpenMotionLab/MotionGPT / forward

Method forward

mGPT/archs/tools/quantize_cnn.py:139–165  ·  view source on GitHub ↗
(self, z)

Source from the content-addressed store, hash-verified

137 self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
138
def forward(self, z):
    """Quantize ``z`` against the codebook and report loss and perplexity.

    Args:
        z: Rank-3 tensor, unpacked here as ``(N, width, T)``. It is passed
            through ``self.preprocess`` (defined elsewhere; presumably it
            moves the feature axis last -- TODO confirm), after which the
            last dimension must equal ``self.e_dim``.

    Returns:
        A ``(z_q, loss, perplexity)`` tuple:
            z_q: the quantized features, reshaped to ``(N, T, -1)`` and
                permuted to ``(N, DIM, T)``. Its forward value is the
                selected codebook vectors; its gradient flows to the
                preprocessed input only.
            loss: ``mean((z_q - sg[z])**2) + beta * mean((sg[z_q] - z)**2)``
                where ``sg`` is ``detach``.
            perplexity: ``exp`` of the entropy of the mean one-hot code
                assignment over all flattened vectors.

    Raises:
        AssertionError: if the preprocessed last dimension is not
            ``self.e_dim``.
    """
    # Unpacking three names also rejects inputs that are not rank 3.
    batch, _, steps = z.shape

    latents = self.preprocess(z)
    assert latents.shape[-1] == self.e_dim
    flat = latents.contiguous().view(-1, self.e_dim)
    codebook = self.embedding.weight

    # Squared distance from every flattened vector to every code (B x V),
    # expanded as |z|^2 + |e|^2 - 2 * z.e so it is one matmul.
    flat_sq = torch.sum(flat ** 2, dim=1, keepdim=True)
    code_sq = torch.sum(codebook ** 2, dim=1)
    cross = torch.matmul(flat, codebook.t())
    distances = flat_sq + code_sq - 2 * cross

    # Index of the closest code for each flattened vector.
    nearest = torch.argmin(distances, dim=1)
    quantized = self.embedding(nearest).view(latents.shape)

    # First term moves the codes toward the (detached) inputs; the second,
    # weighted by beta, moves the inputs toward the (detached) codes.
    codebook_term = torch.mean((quantized - latents.detach()) ** 2)
    commitment_term = torch.mean((quantized.detach() - latents) ** 2)
    loss = codebook_term + self.beta * commitment_term

    # Keep the quantized value in the forward pass while letting the
    # gradient bypass the non-differentiable argmin and reach `latents`.
    quantized = latents + (quantized - latents).detach()
    z_q = quantized.view(batch, steps, -1).permute(0, 2, 1).contiguous()  # (N, DIM, T)

    # Perplexity of code usage; the 1e-10 keeps log() finite for codes
    # that were never selected.
    assignments = F.one_hot(nearest, self.n_e).type(latents.dtype)
    usage = torch.mean(assignments, dim=0)
    perplexity = torch.exp(-torch.sum(usage * torch.log(usage + 1e-10)))
    return z_q, loss, perplexity
166
167 def quantize(self, z):
168

Callers

nothing calls this directly

Calls 2

preprocess — Method · 0.95
detach — Method · 0.80

Tested by

no test coverage detected