| 146 | return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners} interpolation={self.interpolation}" |
| 147 | |
def forward(self, inputs, bound=1):
    """Encode positions with the multi-resolution grid via ``grid_encode``.

    Args:
        inputs: tensor of shape [..., input_dim] holding real-world
            positions, expected to lie in [-bound, bound].
        bound: half-width of the input domain; positions are mapped
            affinely from [-bound, bound] to [0, 1] before encoding.

    Returns:
        Tensor of shape [..., num_levels * level_dim] (``self.output_dim``
        in the last dimension), with the same leading dimensions as
        ``inputs``.
    """
    inputs = (inputs + bound) / (2 * bound)  # map to [0, 1]

    prefix_shape = list(inputs.shape[:-1])
    # reshape, not view: elementwise ops preserve the input's stride
    # layout, so a permuted/transposed batch is still non-contiguous here
    # and .view(-1, input_dim) would raise. reshape returns a view whenever
    # one is possible and only copies otherwise.
    inputs = inputs.reshape(-1, self.input_dim)

    outputs = grid_encode(
        inputs,
        self.embeddings,
        self.offsets,
        self.per_level_scale,
        self.base_resolution,
        inputs.requires_grad,
        self.gridtype_id,
        self.align_corners,
        self.interp_id,
    )
    # Restore the caller's leading dimensions.
    return outputs.reshape(prefix_shape + [self.output_dim])
| 165 | |
| 166 | # always run in float precision! |
| 167 | @torch.cuda.amp.autocast(enabled=False) |