(
self, dataset, collate_fn=None, use_ddp=False, ddp_seed=0, **kwargs
)
| 1431 | collator_arglist = inspect.getfullargspec(GraphCollator).args |
| 1432 | |
| 1433 | def __init__( |
| 1434 | self, dataset, collate_fn=None, use_ddp=False, ddp_seed=0, **kwargs |
| 1435 | ): |
| 1436 | collator_kwargs = {} |
| 1437 | dataloader_kwargs = {} |
| 1438 | for k, v in kwargs.items(): |
| 1439 | if k in self.collator_arglist: |
| 1440 | collator_kwargs[k] = v |
| 1441 | else: |
| 1442 | dataloader_kwargs[k] = v |
| 1443 | |
| 1444 | self.use_ddp = use_ddp |
| 1445 | if use_ddp: |
| 1446 | self.dist_sampler = _create_dist_sampler( |
| 1447 | dataset, dataloader_kwargs, ddp_seed |
| 1448 | ) |
| 1449 | dataloader_kwargs["sampler"] = self.dist_sampler |
| 1450 | |
| 1451 | if collate_fn is None and kwargs.get("batch_size", 1) is not None: |
| 1452 | collate_fn = GraphCollator(**collator_kwargs).collate |
| 1453 | |
| 1454 | super().__init__( |
| 1455 | dataset=dataset, collate_fn=collate_fn, **dataloader_kwargs |
| 1456 | ) |
| 1457 | |
| 1458 | def set_epoch(self, epoch): |
| 1459 | """Sets the epoch number for the underlying sampler which ensures all replicas |
nothing calls this directly
no test coverage detected