
Class GraphCollator

python/dgl/dataloading/dataloader.py:1279–1373 (github.com/dmlc/dgl)

Given a set of graphs and their graph-level data, the collate function batches the graphs into a single batched graph and stacks the tensors into a single larger tensor. If an example is a container (such as a sequence or a mapping), the collate function preserves the structure and collates each of the elements recursively.
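The recursive behavior described above can be sketched without DGL or PyTorch installed. This is a minimal, dependency-free sketch, not DGL's implementation: `ToyGraph`, `batch_toy_graphs`, and `toy_collate` are hypothetical names, a graph is reduced to a node count plus an edge list, and plain Python lists stand in for stacked tensors. Batching is modeled as a disjoint union in which each graph's node IDs are shifted by the number of nodes that precede it, similar in spirit to `dgl.batch`.

```python
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from typing import Any, List, Tuple


@dataclass
class ToyGraph:
    """Hypothetical stand-in for DGLGraph: a node count plus an edge list."""
    num_nodes: int
    edges: List[Tuple[int, int]] = field(default_factory=list)
    # Per-graph node counts, recorded when several graphs are batched together.
    batch_num_nodes: List[int] = field(default_factory=list)


def batch_toy_graphs(graphs: List[ToyGraph]) -> ToyGraph:
    """Disjoint union: shift each graph's node IDs by the nodes that came before it."""
    offset = 0
    edges = []
    for g in graphs:
        edges.extend((u + offset, v + offset) for u, v in g.edges)
        offset += g.num_nodes
    return ToyGraph(offset, edges, [g.num_nodes for g in graphs])


def toy_collate(items: List[Any]) -> Any:
    """Simplified recursive collate: graphs -> batched graph, scalars -> list,
    mappings -> dict of collated values, sequences -> transposed and collated."""
    elem = items[0]
    if isinstance(elem, ToyGraph):
        return batch_toy_graphs(items)
    if isinstance(elem, (int, float)):
        # Stand-in for stacking scalars into a 1-D tensor.
        return list(items)
    if isinstance(elem, (str, bytes)):
        # Strings are Sequences, so they must be handled before the Sequence branch.
        return list(items)
    if isinstance(elem, Mapping):
        return {key: toy_collate([d[key] for d in items]) for key in elem}
    if isinstance(elem, Sequence):
        if any(len(x) != len(elem) for x in items):
            raise RuntimeError("each element in the batch should be of equal size")
        # Transpose [(g1, y1), (g2, y2)] into [(g1, g2), (y1, y2)] and recurse.
        return [toy_collate(list(group)) for group in zip(*items)]
    raise TypeError(f"toy_collate: unsupported element type {type(elem)}")
```

With `(graph, label)` pairs, `toy_collate([(g1, 0), (g2, 1)])` returns a two-element list that unpacks as `batched_graph, labels`, matching the loop in the docstring example; a plain list of graphs with no graph-level data yields just the batched graph.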


```python
class GraphCollator(object):
    """Given a set of graphs as well as their graph-level data, the collate function will batch the
    graphs into a batched graph, and stack the tensors into a single bigger tensor. If the
    example is a container (such as sequences or mapping), the collate function preserves
    the structure and collates each of the elements recursively.

    If the set of graphs has no graph-level data, the collate function will yield a batched graph.

    Examples
    --------
    To train a GNN for graph classification on a set of graphs in ``dataset`` (assume
    the backend is PyTorch):

    >>> dataloader = dgl.dataloading.GraphDataLoader(
    ...     dataset, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for batched_graph, labels in dataloader:
    ...     train_on(batched_graph, labels)
    """

    def __init__(self):
        self.graph_collate_err_msg_format = (
            "graph_collate: batch must contain DGLGraph, tensors, numpy arrays, "
            "numbers, dicts or lists; found {}"
        )
        self.np_str_obj_array_pattern = re.compile(r"[SaUO]")

    # This implementation is based on torch.utils.data._utils.collate.default_collate
    def collate(self, items):
        """This function is similar to ``torch.utils.data._utils.collate.default_collate``.
        It combines the sampled graphs and corresponding graph-level data
        into a batched graph and tensors.

        Parameters
        ----------
        items : list of data points or tuples
            Elements in the list are expected to have the same length.
            Each sub-element will be batched as a batched graph, or a
            batched tensor correspondingly.

        Returns
        -------
        A tuple of the batching results.
        """
        elem = items[0]
        elem_type = type(elem)
        if isinstance(elem, DGLGraph):
            batched_graphs = batch_graphs(items)
            return batched_graphs
        elif F.is_tensor(elem):
            return F.stack(items, 0)
        elif (
            elem_type.__module__ == "numpy"
            and elem_type.__name__ != "str_"
            and elem_type.__name__ != "string_"
        ):
            if (
                elem_type.__name__ == "ndarray"
                or elem_type.__name__ == "memmap"
                # ... (excerpt ends at source line 1336; lines 1337–1373 not shown)
```
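The excerpt compiles `np_str_obj_array_pattern = re.compile(r"[SaUO]")` but the lines that use it are cut off. In `torch.utils.data._utils.collate.default_collate`, which the comment above `collate` names as the basis, this pattern is searched against a NumPy array's `dtype.str` (for example `'<f4'` for little-endian `float32`), and a match raises `TypeError` because string and object arrays cannot be converted to tensors. The sketch below assumes DGL uses the pattern the same way; `is_string_or_object_dtype` is a hypothetical helper, and it takes dtype strings directly so NumPy is not required.

```python
import re

# Same pattern the excerpt compiles in GraphCollator.__init__.
NP_STR_OBJ_ARRAY_PATTERN = re.compile(r"[SaUO]")


def is_string_or_object_dtype(dtype_str: str) -> bool:
    """Return True if a NumPy dtype string (e.g. '<U5') contains a string
    type code ('S', 'a', 'U') or the object code ('O').

    Assumption: mirrors torch's default_collate, which raises TypeError
    when the pattern is found in `elem.dtype.str`.
    """
    return NP_STR_OBJ_ARRAY_PATTERN.search(dtype_str) is not None


for s in ["<f4", "<i8", "|b1", "<U5", "|S3", "|O"]:
    print(s, is_string_or_object_dtype(s))
```

Numeric and boolean dtype strings use lowercase kind codes (`f`, `i`, `b`), so only string and object arrays are rejected.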
