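The recursive rule that `GraphCollator.collate` follows can be illustrated without DGL or PyTorch: batch graphs, stack tensors, and recurse into sequences and mappings. The sketch below rests on several assumptions. `ToyGraph`, `toy_batch`, and `toy_collate` are hypothetical names, not DGL APIs. Graph batching is modeled as a disjoint union that shifts node IDs by an offset, which is roughly what a batched graph represents. "Stacking" numbers is modeled as building a Python list. Sequences are transposed before recursing, in the same way that PyTorch's `default_collate` handles a batch of tuples.

```python
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field


@dataclass
class ToyGraph:
    """Hypothetical stand-in for DGLGraph: a node count plus an edge list."""
    num_nodes: int
    edges: list = field(default_factory=list)  # list of (src, dst) pairs


def toy_batch(graphs):
    """Hypothetical batching: disjoint union, shifting node IDs by an offset."""
    edges, offset = [], 0
    for g in graphs:
        edges.extend((u + offset, v + offset) for u, v in g.edges)
        offset += g.num_nodes
    return ToyGraph(offset, edges)


def toy_collate(items):
    """Sketch of recursive collation in the style of GraphCollator.collate."""
    elem = items[0]
    if isinstance(elem, ToyGraph):
        return toy_batch(items)
    if isinstance(elem, (int, float)):
        # Stand-in for stacking scalars or tensors along a new leading dimension.
        return list(items)
    if isinstance(elem, Mapping):
        # Keep the keys; collate the values found under each key.
        return {key: toy_collate([d[key] for d in items]) for key in elem}
    if isinstance(elem, Sequence) and not isinstance(elem, str):
        # Transpose a list of tuples into per-position groups, then recurse.
        return [toy_collate(list(samples)) for samples in zip(*items)]
    raise TypeError(f"toy_collate: unsupported element type {type(elem)}")


# Usage: a batch of (graph, label) pairs, as in graph classification.
g1 = ToyGraph(2, [(0, 1)])
g2 = ToyGraph(3, [(0, 2), (1, 2)])
graph, labels = toy_collate([(g1, 0), (g2, 1)])
# graph.num_nodes == 5, graph.edges == [(0, 1), (2, 4), (3, 4)], labels == [0, 1]
```

The structure of each input element is preserved: a batch of `(graph, label)` pairs comes back as one merged graph plus one collection of labels, which is the shape the training loop in the docstring expects.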


class GraphCollator(object):
    """Given a set of graphs and their graph-level data, the collate function batches the
    graphs into a single batched graph and stacks the tensors into a single larger tensor.
    If an example is a container (such as a sequence or mapping), the collate function
    preserves its structure and collates each of its elements recursively.

    If the graphs have no graph-level data, the collate function returns only the batched graph.

    Examples
    --------
    To train a GNN for graph classification on a set of graphs in ``dataset`` (assuming
    the backend is PyTorch):

    >>> dataloader = dgl.dataloading.GraphDataLoader(
    ...     dataset, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for batched_graph, labels in dataloader:
    ...     train_on(batched_graph, labels)
    """

    def __init__(self):
        self.graph_collate_err_msg_format = (
            "graph_collate: batch must contain DGLGraph, tensors, numpy arrays, "
            "numbers, dicts or lists; found {}"
        )
        self.np_str_obj_array_pattern = re.compile(r"[SaUO]")

    # This implementation is based on torch.utils.data._utils.collate.default_collate
    def collate(self, items):
        """This function is similar to ``torch.utils.data._utils.collate.default_collate``.
        It combines the sampled graphs and their corresponding graph-level data
        into a batched graph and tensors.

        Parameters
        ----------
        items : list of data points or tuples
            Elements in the list are expected to have the same length.
            Each sub-element is batched into a batched graph or a stacked
            tensor, according to its type.

        Returns
        -------
        The batched result, which preserves the structure of each element of ``items``.
        """
        elem = items[0]
        elem_type = type(elem)
        if isinstance(elem, DGLGraph):
            batched_graphs = batch_graphs(items)
            return batched_graphs
        elif F.is_tensor(elem):
            return F.stack(items, 0)
        elif (
            elem_type.__module__ == "numpy"
            and elem_type.__name__ != "str_"
            and elem_type.__name__ != "string_"
        ):
            if (
                elem_type.__name__ == "ndarray"
                or elem_type.__name__ == "memmap"