(
self,
query: Query,
params: Optional[Dict[str, torch.Tensor]] = None,
options: Optional[Dict[str, Any]] = None,
)
| 172 | return h |
| 173 | |
| 174 | def forward( |
| 175 | self, |
| 176 | query: Query, |
| 177 | params: Optional[Dict[str, torch.Tensor]] = None, |
| 178 | options: Optional[Dict[str, Any]] = None, |
| 179 | ) -> AttrDict: |
| 180 | params = self.update(params) |
| 181 | |
| 182 | options = AttrDict() if options is None else AttrDict(options) |
| 183 | |
| 184 | query = query.copy() |
| 185 | |
| 186 | h_position = self.encode_position(query) |
| 187 | |
| 188 | if self.meta_parameters: |
| 189 | density_params = subdict(params, "density_mlp") |
| 190 | density_mlp = partial( |
| 191 | self.density_mlp, params=density_params, options=options, log_prefix="density_" |
| 192 | ) |
| 193 | density_mlp_parameters = list(density_params.values()) |
| 194 | else: |
| 195 | density_mlp = partial(self.density_mlp, options=options, log_prefix="density_") |
| 196 | density_mlp_parameters = self.density_mlp.parameters() |
| 197 | h_density = checkpoint( |
| 198 | density_mlp, |
| 199 | (h_position,), |
| 200 | density_mlp_parameters, |
| 201 | options.checkpoint_nerf_mlp, |
| 202 | ) |
| 203 | h_direction = maybe_get_spherical_harmonics_basis( |
| 204 | sh_degree=self.sh_degree, |
| 205 | coords_shape=query.position.shape, |
| 206 | coords=query.direction, |
| 207 | device=query.position.device, |
| 208 | ) |
| 209 | |
| 210 | if self.meta_parameters: |
| 211 | channel_params = subdict(params, "channel_mlp") |
| 212 | channel_mlp = partial( |
| 213 | self.channel_mlp, params=channel_params, options=options, log_prefix="channel_" |
| 214 | ) |
| 215 | channel_mlp_parameters = list(channel_params.values()) |
| 216 | else: |
| 217 | channel_mlp = partial(self.channel_mlp, options=options, log_prefix="channel_") |
| 218 | channel_mlp_parameters = self.channel_mlp.parameters() |
| 219 | h_channel = checkpoint( |
| 220 | channel_mlp, |
| 221 | (torch.cat([h_density, h_direction], dim=-1),), |
| 222 | channel_mlp_parameters, |
| 223 | options.checkpoint_nerf_mlp, |
| 224 | ) |
| 225 | |
| 226 | density_logit = h_density[..., :1] |
| 227 | |
| 228 | res = AttrDict( |
| 229 | density_logit=density_logit, |
| 230 | density=self.density_act(density_logit), |
| 231 | channels=self.act(h_channel), |
nothing calls this directly
no test coverage detected