| 359 | return wrap(o) |
| 360 | |
def to(self, *args, **kwargs) -> torch.Tensor:
    """Lazy counterpart of ``torch.Tensor.to``.

    Accepts the same positional/keyword arguments that are forwarded to the
    underlying tensor's ``to``.

    * When ``self._materialized_data`` is set, the conversion is applied to
      it right away and the result is wrapped in a new ``LazyTensor`` via
      ``concrete_data``.
    * Otherwise the conversion is applied to ``self._meta_data`` with every
      device-like argument swapped for the ``meta`` device, and the real
      conversion is deferred through a factory function that replays the
      original call.

    Returns:
        ``self`` when the meta conversion returned the very same meta tensor
        and the requested device equals ``self.device``; a new ``LazyTensor``
        in every other case.
    """
    concrete = self._materialized_data
    if concrete is not None:
        # Concrete data already exists, so convert it immediately.
        return LazyTensor(lambda: None, concrete_data=concrete.to(*args, **kwargs))

    # Device-like arguments in visiting order (positional first, then
    # keyword); the last one seen is the one that is kept.
    requested_devices = []

    def redirect_to_meta(arg):
        # bool is a subclass of int, so flags such as ``non_blocking=True``
        # must be excluded explicitly from the device-like check.
        is_device_like = isinstance(arg, (str, int, torch.device))
        if isinstance(arg, bool) or not is_device_like:
            return arg
        requested_devices.append(arg)
        return torch.device("meta")

    meta_args = tree_map(redirect_to_meta, args)
    meta_kwargs = tree_map(redirect_to_meta, kwargs)
    meta_data = self._meta_data.to(*meta_args, **meta_kwargs)
    device = requested_devices[-1] if requested_devices else None

    # Nothing changed on the meta side and the target device is the current
    # one: hand back the same object.
    if meta_data is self._meta_data and device == self.device:
        return self

    def factory_fn(t: torch.Tensor, **kw):
        # Replays the caller's original arguments; extra keywords are ignored.
        return t.to(*args, **kwargs)

    return LazyTensor(factory_fn, self, meta_data=meta_data, device=device)
| 383 | |
def cpu(self, memory_format: torch.memory_format = torch.preserve_format):
    """Return the result of ``self.to`` targeting the CPU device.

    Args:
        memory_format: forwarded unchanged to ``to``; defaults to
            ``torch.preserve_format``.
    """
    cpu_device = torch.device("cpu")
    return self.to(device=cpu_device, memory_format=memory_format)