(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False, interpolation='linear')
| 95 | |
| 96 | class GridEncoder(nn.Module): |
def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False, interpolation='linear'):
    """Build a multi-resolution grid encoder with one feature table per level.

    All levels share a single ``embeddings`` parameter of shape
    ``(total_entries, level_dim)``. The ``offsets`` buffer (int32, length
    ``num_levels + 1``) gives the first row of each level's slice.

    Args:
        input_dim: number of coordinate dimensions (2 or 3 per the original notes).
        num_levels: number of resolution levels.
        level_dim: feature channels stored per table entry at each level.
        per_level_scale: factor by which the resolution grows from one level to
            the next. Overridden when ``desired_resolution`` is given and
            ``num_levels > 1``.
        base_resolution: resolution (number of cells along one axis) at level 0.
        log2_hashmap_size: log2 of the maximum number of entries in one level.
            A level whose dense grid would need more entries is capped at
            ``2 ** log2_hashmap_size``. For a 2D grid, dense storage fits when
            ``2 ** log2_hashmap_size >= vertices_per_axis ** 2``.
        desired_resolution: optional resolution wanted at the last level. When
            given, ``per_level_scale`` is derived from it.
        gridtype: key into ``_gridtype_to_id``, presumably ``'hash'`` or
            ``'tiled'``. The mapping is defined outside this block.
        align_corners: when True a level is sized with ``resolution`` vertices
            per axis, otherwise with ``resolution + 1``.
        interpolation: key into ``_interp_to_id``, presumably ``'linear'`` or
            ``'smoothstep'``. The mapping is defined outside this block.

    Raises:
        ValueError: if the total number of table entries does not fit in the
            int32 ``offsets`` buffer.
    """
    super().__init__()

    # Derive the per-level growth factor from the finest resolution wanted at
    # the last level. With a single level the exponent denominator
    # (num_levels - 1) would be zero, and the scale is irrelevant anyway
    # (only level 0 exists), so the given per_level_scale is kept.
    if desired_resolution is not None and num_levels > 1:
        per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))

    self.input_dim = input_dim  # coordinate dims
    self.num_levels = num_levels  # number of levels; level i uses base_resolution * per_level_scale ** i
    self.level_dim = level_dim  # feature channels per level
    self.per_level_scale = per_level_scale  # resolution growth factor between consecutive levels
    self.log2_hashmap_size = log2_hashmap_size
    self.base_resolution = base_resolution
    self.output_dim = num_levels * level_dim  # features from all levels, concatenated
    self.gridtype = gridtype
    self.gridtype_id = _gridtype_to_id[gridtype]  # presumably "hash" or "tiled"
    self.interpolation = interpolation
    self.interp_id = _interp_to_id[interpolation]  # presumably "linear" or "smoothstep"
    self.align_corners = align_corners

    # Lay out every level's table back to back in one parameter and record
    # where each level starts.
    self.max_params = 2 ** log2_hashmap_size
    offsets = []
    offset = 0
    for i in range(num_levels):
        resolution = int(np.ceil(base_resolution * per_level_scale ** i))
        vertices_per_axis = resolution if align_corners else resolution + 1
        # Dense grid size, capped at the hash-table size.
        params_in_level = min(self.max_params, vertices_per_axis ** input_dim)
        # Round up to a multiple of 8 using exact integer arithmetic.
        params_in_level = -(-params_in_level // 8) * 8
        offsets.append(offset)
        offset += params_in_level
    offsets.append(offset)

    # Offsets are stored as int32. Fail loudly rather than letting the cast
    # wrap around (older NumPy versions wrap silently).
    int32_max = np.iinfo(np.int32).max
    if offset > int32_max:
        raise ValueError(
            f"GridEncoder needs {offset} table entries, which exceeds the int32 "
            f"offsets limit ({int32_max}); reduce num_levels or log2_hashmap_size"
        )
    self.register_buffer('offsets', torch.from_numpy(np.array(offsets, dtype=np.int32)))

    # Total number of learnable scalars, as a plain Python int.
    self.n_params = offset * level_dim

    # One row per table entry across all levels; filled by reset_parameters().
    self.embeddings = nn.Parameter(torch.empty(offset, level_dim))

    self.reset_parameters()
| 140 | |
| 141 | def reset_parameters(self): |
| 142 | std = 1e-4 |
nothing calls this directly
no test coverage detected