Recursively materialize a stored value, creating tensors from metadata and keeping non-tensors as-is
(self, stored_value: Any)
| 79 | return value |
| 80 | |
| 81 | def _materialize_value(self, stored_value: Any) -> Any: |
| 82 | """ |
| 83 | Recursively materialize a stored value, creating tensors from metadata and keeping non-tensors as-is |
| 84 | """ |
| 85 | if isinstance(stored_value, TensorMetadata): |
| 86 | # If we have real data stored, use it |
| 87 | if stored_value.real_data is not None: |
| 88 | try: |
| 89 | # Use the stored real data |
| 90 | tensor = stored_value.real_data.clone() |
| 91 | |
| 92 | # Set stride if different from default and tensor is contiguous |
| 93 | if tensor.stride() != stored_value.stride and len(stored_value.shape) > 0: |
| 94 | try: |
| 95 | # Create tensor with specific stride |
| 96 | tensor = torch.as_strided(tensor, stored_value.shape, stored_value.stride, |
| 97 | stored_value.storage_offset) |
| 98 | except RuntimeError: |
| 99 | # If stride setting fails, use default stride |
| 100 | pass |
| 101 | |
| 102 | # Move to target device and set requires_grad |
| 103 | tensor = tensor.to(device=stored_value.device) |
| 104 | tensor.requires_grad_(stored_value.requires_grad) |
| 105 | |
| 106 | return tensor |
| 107 | |
| 108 | except Exception as e: |
| 109 | # Fallback to dummy data if real data fails |
| 110 | pass |
| 111 | |
| 112 | # Create a tensor with the stored metadata (original behavior for non-int tensors) |
| 113 | # Use CPU first to avoid GPU memory issues, then move to target device |
| 114 | try: |
| 115 | tensor = torch.empty(stored_value.shape, |
| 116 | dtype=stored_value.dtype, |
| 117 | layout=stored_value.layout, |
| 118 | device='cpu') |
| 119 | |
| 120 | # Fill with dummy data (ones) for profiling purposes |
| 121 | tensor.fill_(1.0) |
| 122 | |
| 123 | # Set stride if different from default and tensor is contiguous |
| 124 | if tensor.stride() != stored_value.stride and len(stored_value.shape) > 0: |
| 125 | try: |
| 126 | # Create tensor with specific stride |
| 127 | tensor = torch.as_strided(tensor, stored_value.shape, stored_value.stride, |
| 128 | stored_value.storage_offset) |
| 129 | except RuntimeError: |
| 130 | # If stride setting fails, use default stride |
| 131 | pass |
| 132 | |
| 133 | # Move to target device and set requires_grad |
| 134 | tensor = tensor.to(device=stored_value.device) |
| 135 | tensor.requires_grad_(stored_value.requires_grad) |
| 136 | |
| 137 | return tensor |
| 138 |