Get rays. Args: poses: [B, 4, 4], cam2world; intrinsics: [4]; H, W, N: int. Returns: rays_o, rays_d: [B, N, 3]; inds: [B, N].
(poses, intrinsics, H, W, N=-1, patch_size=1, rect=None)
| 281 | |
| 282 | |
| 283 | def get_rays(poses, intrinsics, H, W, N=-1, patch_size=1, rect=None): |
| 284 | ''' get rays |
| 285 | Args: |
| 286 | poses: [B, 4, 4], cam2world |
| 287 | intrinsics: [4] |
| 288 | H, W, N: int |
| 289 | Returns: |
| 290 | rays_o, rays_d: [B, N, 3] |
| 291 | inds: [B, N] |
| 292 | ''' |
| 293 | |
| 294 | device = poses.device |
| 295 | B = poses.shape[0] |
| 296 | fx, fy, cx, cy = intrinsics |
| 297 | |
| 298 | if rect is not None: |
| 299 | xmin, xmax, ymin, ymax = rect |
| 300 | N = (xmax - xmin) * (ymax - ymin) |
| 301 | |
| 302 | i, j = custom_meshgrid(torch.linspace(0, W-1, W, device=device), torch.linspace(0, H-1, H, device=device)) # float |
| 303 | i = i.t().reshape([1, H*W]).expand([B, H*W]) + 0.5 |
| 304 | j = j.t().reshape([1, H*W]).expand([B, H*W]) + 0.5 |
| 305 | |
| 306 | results = {} |
| 307 | |
| 308 | if N > 0: |
| 309 | N = min(N, H*W) |
| 310 | |
| 311 | if patch_size > 1: |
| 312 | |
| 313 | # random sample left-top cores. |
| 314 | # NOTE: this impl will lead to less sampling on the image corner pixels... but I don't have other ideas. |
| 315 | num_patch = N // (patch_size ** 2) |
| 316 | inds_x = torch.randint(0, H - patch_size, size=[num_patch], device=device) |
| 317 | inds_y = torch.randint(0, W - patch_size, size=[num_patch], device=device) |
| 318 | inds = torch.stack([inds_x, inds_y], dim=-1) # [np, 2] |
| 319 | |
| 320 | # create meshgrid for each patch |
| 321 | pi, pj = custom_meshgrid(torch.arange(patch_size, device=device), torch.arange(patch_size, device=device)) |
| 322 | offsets = torch.stack([pi.reshape(-1), pj.reshape(-1)], dim=-1) # [p^2, 2] |
| 323 | |
| 324 | inds = inds.unsqueeze(1) + offsets.unsqueeze(0) # [np, p^2, 2] |
| 325 | inds = inds.view(-1, 2) # [N, 2] |
| 326 | inds = inds[:, 0] * W + inds[:, 1] # [N], flatten |
| 327 | |
| 328 | inds = inds.expand([B, N]) |
| 329 | |
| 330 | # only get rays in the specified rect |
| 331 | elif rect is not None: |
| 332 | # assert B == 1 |
| 333 | mask = torch.zeros(H, W, dtype=torch.bool, device=device) |
| 334 | xmin, xmax, ymin, ymax = rect |
| 335 | mask[xmin:xmax, ymin:ymax] = 1 |
| 336 | inds = torch.where(mask.view(-1))[0] # [nzn] |
| 337 | inds = inds.unsqueeze(0) # [1, N] |
| 338 | |
| 339 | else: |
| 340 | inds = torch.randint(0, H*W, size=[N], device=device) # may duplicate |