Parameters: ptr: GPU memory address obtained from cudaMalloc (Python int); segment_size: Memory size of each segment in bytes; segment_stride: Memory stride between consecutive segments in bytes; num_segments: Number of segments; torch_dtype: torch dtype; dev_id: device id
(ptr, segment_size, segment_stride, num_segments, torch_dtype, dev_id)
| 107 | |
| 108 | |
| 109 | def create_dlpack_capsule(ptr, segment_size, segment_stride, num_segments, torch_dtype, dev_id): |
| 110 | """ |
| 111 | Parameters: |
| 112 | ptr: GPU memory address obtained from cudaMalloc (Python int) |
| 113 | segment_size: Memory size of each segment in bytes |
| 114 | segment_stride: Memory stride size between segments in bytes |
| 115 | num_segments: Number of segments |
| 116 | torch_dtype: torch dtype |
| 117 | dev_id: device id. |
| 118 | Returns: |
| 119 | A PyCapsule object compliant with DLPack specification, which can be directly converted to a |
| 120 | tensor using torch.utils.dlpack.from_dlpack |
| 121 | """ |
| 122 | bits_per_elements = 0 |
| 123 | dldata_type_code = 0 |
| 124 | # refer to https://github.com/dmlc/dlpack/blob/main/include/dlpack/dlpack.h#L160 |
| 125 | if torch_dtype in [ |
| 126 | torch.float8_e5m2, |
| 127 | torch.float8_e4m3fn, |
| 128 | torch.bfloat16, |
| 129 | torch.float16, |
| 130 | torch.float32, |
| 131 | torch.float64, |
| 132 | ]: |
| 133 | bits_per_elements = torch.finfo(torch_dtype).bits |
| 134 | dldata_type_code = 2 |
| 135 | elif torch_dtype in [torch.int8, torch.int16, torch.int32, torch.int64]: |
| 136 | bits_per_elements = torch.iinfo(torch_dtype).bits |
| 137 | dldata_type_code = 0 |
| 138 | elif torch_dtype in [torch.uint8, torch.uint16, torch.uint32, torch.uint64]: |
| 139 | bits_per_elements = torch.iinfo(torch_dtype).bits |
| 140 | dldata_type_code = 1 |
| 141 | else: |
| 142 | raise NotImplementedError(torch_dtype) |
| 143 | bytes_per_element = bits_per_elements // 8 |
| 144 | # Allocate space for shape and strides (constructing a two-dimensional tensor here: num_segments x elements per segment) |
| 145 | ShapeArrayType = c_int64 * 2 # 1 dimension |
| 146 | shape_array = ShapeArrayType(num_segments, segment_size // bytes_per_element) |
| 147 | stride_array = ShapeArrayType(segment_stride // bytes_per_element, 1) |
| 148 | # Set device information: GPU (device_type=2) and device_id=dev_id (modify as needed) |
| 149 | device = DLDevice(device_type=2, device_id=dev_id) |
| 150 | # Set data type |
| 151 | dtype = DLDataType(code=dldata_type_code, bits=bits_per_elements, lanes=1) |
| 152 | # Construct DLTensor |
| 153 | dltensor = DLTensor() |
| 154 | dltensor.data = c_void_p(ptr) |
| 155 | dltensor.device = device |
| 156 | dltensor.ndim = 2 |
| 157 | dltensor.dtype = dtype |
| 158 | dltensor.shape = ctypes.cast(shape_array, POINTER(c_int64)) |
| 159 | dltensor.strides = ctypes.cast(stride_array, POINTER(c_int64)) |
| 160 | dltensor.byte_offset = 0 |
| 161 | # Construct DLManagedTensor and set deleter to no-op (you can also call cudaFree here) |
| 162 | managed_tensor = DLManagedTensor() |
| 163 | managed_tensor.dl_tensor = dltensor |
| 164 | managed_tensor.manager_ctx = None |
| 165 | managed_tensor.deleter = no_op_deleter |
| 166 | # Note: Must ensure that shape_array and managed_tensor are not garbage collected by Python, |
no test coverage detected