MCPcopy
hub / github.com/yerfor/GeneFacePlusPlus / sample_pdf

Function sample_pdf

modules/radnerfs/renderer.py:17–51  ·  view source on GitHub ↗
(bins, weights, n_samples, det=False)

Source from the content-addressed store, hash-verified

15
16
def sample_pdf(bins, weights, n_samples, det=False):
    """Draw new sample positions by inverse-transform sampling a piecewise-constant PDF.

    This implementation is adapted from NeRF's hierarchical sampling. Each row
    of ``bins`` holds ``T`` bin edges, and the matching row of ``weights``
    holds the ``T - 1`` unnormalized weights of the intervals between them.
    The weights are normalized into a PDF and accumulated into a CDF. Then
    ``n_samples`` values ``u`` in ``[0, 1)`` are mapped back through the
    inverse CDF, interpolating linearly inside the selected interval.

    Args:
        bins: ``[..., T]`` tensor of bin edges (old z values). Assumed sorted
            ascending along the last dim -- TODO confirm at call sites. May
            broadcast over the leading dims of ``weights`` (e.g. ``[1, T]``).
        weights: ``[..., T - 1]`` tensor of bin weights.
        n_samples: number of samples drawn per row.
        det: if True, use evenly spaced ``u`` at the centres of
            ``n_samples`` equal slices of [0, 1]; otherwise draw ``u``
            uniformly at random.

    Returns:
        ``[..., n_samples]`` tensor of new sample positions (new z values),
        with the same leading dims as ``weights``.
    """
    # Get pdf; the epsilon keeps all-zero weight rows from producing NaNs.
    weights = weights + 1e-5
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    # Prepend 0 so the CDF has T entries, one per bin edge.
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)

    # Take uniform samples, created directly on the target device instead of
    # on the CPU followed by a copy.
    sample_shape = list(cdf.shape[:-1]) + [n_samples]
    if det:
        u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples,
                           steps=n_samples, device=weights.device)
        u = u.expand(sample_shape)
    else:
        u = torch.rand(sample_shape, device=weights.device)

    # Invert CDF: for each u, find the CDF interval [below, above] holding it.
    u = u.contiguous()  # expand() yields a non-contiguous view; searchsorted wants contiguous
    inds = torch.searchsorted(cdf, u, right=True)
    below = (inds - 1).clamp(min=0)
    above = inds.clamp(max=cdf.shape[-1] - 1)

    # Gathering along the last dim works for any number of leading dims; the
    # index may be longer than T along that dim. Expand bins first so a
    # broadcastable bins (e.g. [1, T]) still matches the index's leading dims.
    bins = bins.expand(cdf.shape[:-1] + bins.shape[-1:])
    cdf_below = torch.gather(cdf, -1, below)
    cdf_above = torch.gather(cdf, -1, above)
    bins_below = torch.gather(bins, -1, below)
    bins_above = torch.gather(bins, -1, above)

    # Near-empty intervals would blow up t; use denom = 1 there instead.
    denom = cdf_above - cdf_below
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_below) / denom
    samples = bins_below + t * (bins_above - bins_below)

    return samples
52
53
54def plot_pointcloud(pc, color=None):

Callers

nothing calls this directly

Calls 1

to · Method · 0.45

Tested by

no test coverage detected