(batch: AttrDict, *, pc_scale: float, color_scale: float)
| 213 | |
| 214 | |
def normalize_input_batch(batch: AttrDict, *, pc_scale: float, color_scale: float) -> AttrDict:
    """
    Return a copy of ``batch`` with its point, camera and depth data rescaled.

    The ``points`` tensor is multiplied channel-wise by a 6-element multiplier:
    the first three channels by ``pc_scale`` and the last three by
    ``color_scale`` (presumably xyz followed by rgb -- confirm with callers).
    The multiplier is reshaped to (6, 1) before broadcasting, so ``points``
    must have a length-6 (or broadcastable) axis second from the end --
    presumably (batch, 6, num_points); TODO confirm.

    If present, ``cameras`` (a nested list) is replaced by the result of
    calling ``scale_scene(pc_scale)`` on every camera. ``depths`` (a nested
    list) is replaced by every depth multiplied by ``pc_scale``.

    :param batch: input mapping with at least a ``points`` tensor.
    :param pc_scale: multiplier for the first three point channels, the
        cameras and the depths.
    :param color_scale: multiplier for the last three point channels.
    :return: the rescaled copy produced by ``batch.copy()``. Fields are
        reassigned on the copy rather than on ``batch`` (assumes
        ``AttrDict.copy`` returns a new mapping).
    """
    out = batch.copy()

    # One multiplier per channel, created on the same device as the points so
    # the elementwise product does not cross devices.
    channel_scales = torch.tensor(
        [pc_scale, pc_scale, pc_scale, color_scale, color_scale, color_scale],
        device=batch.points.device,
    )
    out.points = out.points * channel_scales.unsqueeze(-1)

    if "cameras" in out:
        out.cameras = [
            [camera.scale_scene(pc_scale) for camera in view_cameras]
            for view_cameras in out.cameras
        ]

    if "depths" in out:
        out.depths = [
            [d * pc_scale for d in item_depths] for item_depths in out.depths
        ]

    return out
| 227 | |
| 228 | |
| 229 | def process_depth(depth_img: np.ndarray, image_size: int) -> np.ndarray: |
no test coverage detected