| 158 | return torch.where(x < 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4) |
| 159 | |
def _meshgrid_ij(*axes):
    """Return ``torch.meshgrid(*axes)`` with matrix ('ij') indexing on any torch version.

    Passing ``indexing='ij'`` avoids the deprecation warning that torch >= 1.10
    emits when the argument is omitted. Older torch releases do not accept the
    keyword, but 'ij' was their only behaviour, so the result is the same.
    """
    try:
        return torch.meshgrid(*axes, indexing='ij')
    except TypeError:
        # torch < 1.10: no `indexing` keyword; legacy behaviour is already 'ij'.
        return torch.meshgrid(*axes)


def extract_fields(bound_min, bound_max, resolution, query_func, *, batch_size=256):
    """Evaluate ``query_func`` on a regular 3-D grid and return the values as a NumPy array.

    The axis-aligned box ``[bound_min, bound_max]`` is sampled with ``resolution``
    evenly spaced points per axis (``torch.linspace``, endpoints included). The grid
    is processed in cubic chunks of at most ``batch_size`` samples per axis, to bound
    the number of points passed to ``query_func`` per call.

    Args:
        bound_min: Indexable with at least 3 entries; the lower x/y/z corner.
        bound_max: Indexable with at least 3 entries; the upper x/y/z corner.
        resolution: Number of samples along each axis.
        query_func: Callable that receives a tensor of shape ``[M, 3]`` (x, y, z
            columns). Its result must have exactly ``M`` elements, so that it can be
            reshaped to the chunk's ``[nx, ny, nz]`` shape, and must support
            ``.detach().cpu().numpy()``. The points are created by
            ``torch.linspace`` with the default dtype/device. Presumably
            ``query_func`` moves them to the right device (TODO confirm at callers).
        batch_size: Maximum number of samples per axis in one chunk (default 256).

    Returns:
        A ``float32`` array of shape ``(resolution, resolution, resolution)``, where
        ``u[i, j, k]`` is the value of ``query_func`` at ``(X[i], Y[j], Z[k])``.
    """
    N = batch_size
    X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N)
    Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N)
    Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N)

    u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
    # NOTE(review): gradient tracking is intentionally left enabled (a
    # `with torch.no_grad():` wrapper was commented out here), presumably because
    # query_func may rely on autograd internally. Outputs are detached before use.
    for xi, xs in enumerate(X):
        for yi, ys in enumerate(Y):
            for zi, zs in enumerate(Z):
                xx, yy, zz = _meshgrid_ij(xs, ys, zs)
                # pts: [len(xs) * len(ys) * len(zs), 3]
                pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)
                # Flat per-point values -> chunk grid [len(xs), len(ys), len(zs)]
                val = query_func(pts).reshape(len(xs), len(ys), len(zs))
                # Every chunk except possibly the last per axis has exactly N samples,
                # so chunk index * N is the chunk's start offset in the full grid.
                u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val.detach().cpu().numpy()
                # Drop the reference before the next query so its memory can be freed early.
                del val
    return u
| 177 | |
| 178 | |
| 179 | def extract_geometry(bound_min, bound_max, resolution, threshold, query_func, use_sdf = False): |