Convert the data to an ID tensor and check its ID type and context. If the data is already a tensor, raise an error if its ID type or context does not match the graph's. Otherwise, convert it to a tensor with the graph's ID type and context and return it.
(g, data, name)
| 10 | |
| 11 | |
| 12 | def prepare_tensor(g, data, name): |
| 13 | """Convert the data to ID tensor and check its ID type and context. |
| 14 | |
| 15 | If the data is already in tensor type, raise error if its ID type |
| 16 | and context does not match the graph's. |
| 17 | Otherwise, convert it to tensor type of the graph's ID type and |
| 18 | ctx and return. |
| 19 | |
| 20 | Parameters |
| 21 | ---------- |
| 22 | g : DGLGraph |
| 23 | Graph. |
| 24 | data : int, iterable of int, tensor |
| 25 | Data. |
| 26 | name : str |
| 27 | Name of the data. |
| 28 | |
| 29 | Returns |
| 30 | ------- |
| 31 | Tensor |
| 32 | Data in tensor object. |
| 33 | """ |
| 34 | if F.is_tensor(data): |
| 35 | if F.dtype(data) != g.idtype: |
| 36 | raise DGLError( |
| 37 | f'Expect argument "{name}" to have data type {g.idtype}. ' |
| 38 | f"But got {F.dtype(data)}." |
| 39 | ) |
| 40 | if F.context(data) != g.device and not g.is_pinned(): |
| 41 | raise DGLError( |
| 42 | f'Expect argument "{name}" to have device {g.device}. ' |
| 43 | f"But got {F.context(data)}." |
| 44 | ) |
| 45 | ret = data |
| 46 | else: |
| 47 | data = F.tensor(data) |
| 48 | if not ( |
| 49 | F.ndim(data) > 0 and F.shape(data)[0] == 0 |
| 50 | ) and F.dtype( # empty tensor |
| 51 | data |
| 52 | ) not in ( |
| 53 | F.int32, |
| 54 | F.int64, |
| 55 | ): |
| 56 | raise DGLError( |
| 57 | 'Expect argument "{}" to have data type int32 or int64,' |
| 58 | " but got {}.".format(name, F.dtype(data)) |
| 59 | ) |
| 60 | ret = F.copy_to(F.astype(data, g.idtype), g.device) |
| 61 | |
| 62 | if F.ndim(ret) == 0: |
| 63 | ret = F.unsqueeze(ret, 0) |
| 64 | if F.ndim(ret) > 1: |
| 65 | raise DGLError( |
| 66 | 'Expect a 1-D tensor for argument "{}". But got {}.'.format( |
| 67 | name, ret |
| 68 | ) |
| 69 | ) |
no test coverage detected