(bound_min, bound_max, resolution, query_func, S=128)
| 398 | |
| 399 | |
def extract_fields(bound_min, bound_max, resolution, query_func, S=128):
    """Evaluate ``query_func`` on a dense regular 3-D grid, one chunk at a time.

    The axis-aligned box ``[bound_min, bound_max]`` is sampled with
    ``resolution`` evenly spaced points per axis (both endpoints included, as
    produced by ``torch.linspace``). To bound peak memory, each axis is cut into
    pieces of at most ``S`` samples, so a single ``query_func`` call receives at
    most ``S ** 3`` points.

    Args:
        bound_min: Indexable of 3 scalars, lower (x, y, z) corner of the box.
        bound_max: Indexable of 3 scalars, upper (x, y, z) corner of the box.
        resolution: Number of samples along each axis.
        query_func: Callable taking a float tensor of shape ``[N, 3]`` (one
            (x, y, z) point per row) and returning a tensor with exactly ``N``
            elements (e.g. ``[N]`` or ``[N, 1]``).
        S: Maximum number of samples per axis in one chunk.

    Returns:
        ``np.ndarray`` of shape ``[resolution, resolution, resolution]`` and
        dtype ``float32``; ``u[i, j, k]`` holds the value at the i-th x, j-th y
        and k-th z sample.
    """
    X, Y, Z = (
        torch.linspace(bound_min[d], bound_max[d], resolution).split(S)
        for d in range(3)
    )

    u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
    with torch.no_grad():
        for xi, xs in enumerate(X):
            for yi, ys in enumerate(Y):
                for zi, zs in enumerate(Z):
                    # All (x, y, z) combinations of this chunk in 'ij' (x-major)
                    # order -> [len(xs) * len(ys) * len(zs), 3]. That ordering
                    # is what lets the reshape below match u's axis layout.
                    pts = torch.cartesian_prod(xs, ys, zs)
                    # detach() is a no-op under no_grad unless query_func
                    # re-enables autograd internally; keep it so .numpy() works.
                    val = (
                        query_func(pts)
                        .reshape(len(xs), len(ys), len(zs))
                        .detach()
                        .cpu()
                        .numpy()
                    )
                    # split(S) yields full-size chunks except possibly the last,
                    # so chunk index * S is the chunk's start offset on each axis.
                    u[xi * S: xi * S + len(xs),
                      yi * S: yi * S + len(ys),
                      zi * S: zi * S + len(zs)] = val
    return u
| 416 | |
| 417 | |
| 418 | def extract_geometry(bound_min, bound_max, resolution, threshold, query_func): |
no test coverage detected