MCPcopy
hub / github.com/openai/shap-e / forward

Method forward

shap_e/models/nerf/model.py:174–238  ·  view source on GitHub ↗
(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    )

Source from the content-addressed store, hash-verified

172 return h
173
174 def forward(
175 self,
176 query: Query,
177 params: Optional[Dict[str, torch.Tensor]] = None,
178 options: Optional[Dict[str, Any]] = None,
179 ) -> AttrDict:
180 params = self.update(params)
181
182 options = AttrDict() if options is None else AttrDict(options)
183
184 query = query.copy()
185
186 h_position = self.encode_position(query)
187
188 if self.meta_parameters:
189 density_params = subdict(params, "density_mlp")
190 density_mlp = partial(
191 self.density_mlp, params=density_params, options=options, log_prefix="density_"
192 )
193 density_mlp_parameters = list(density_params.values())
194 else:
195 density_mlp = partial(self.density_mlp, options=options, log_prefix="density_")
196 density_mlp_parameters = self.density_mlp.parameters()
197 h_density = checkpoint(
198 density_mlp,
199 (h_position,),
200 density_mlp_parameters,
201 options.checkpoint_nerf_mlp,
202 )
203 h_direction = maybe_get_spherical_harmonics_basis(
204 sh_degree=self.sh_degree,
205 coords_shape=query.position.shape,
206 coords=query.direction,
207 device=query.position.device,
208 )
209
210 if self.meta_parameters:
211 channel_params = subdict(params, "channel_mlp")
212 channel_mlp = partial(
213 self.channel_mlp, params=channel_params, options=options, log_prefix="channel_"
214 )
215 channel_mlp_parameters = list(channel_params.values())
216 else:
217 channel_mlp = partial(self.channel_mlp, options=options, log_prefix="channel_")
218 channel_mlp_parameters = self.channel_mlp.parameters()
219 h_channel = checkpoint(
220 channel_mlp,
221 (torch.cat([h_density, h_direction], dim=-1),),
222 channel_mlp_parameters,
223 options.checkpoint_nerf_mlp,
224 )
225
226 density_logit = h_density[..., :1]
227
228 res = AttrDict(
229 density_logit=density_logit,
230 density=self.density_act(density_logit),
231 channels=self.act(h_channel),

Callers

nothing calls this directly

Calls 7

encode_position — Method · 0.95
AttrDict — Class · 0.90
subdict — Function · 0.90
checkpoint — Function · 0.90
update — Method · 0.80
copy — Method · 0.80

Tested by

no test coverage detected