| 216 | self.cnt[0] = 0 |
| 217 | |
def sample_blocks(self, g, seed_nodes, exclude_eids=None):
    """Sample one block per configured fanout, starting from the seeds.

    The fanouts are consumed in reverse order: the first block is sampled
    around ``seed_nodes`` and every following block is sampled around the
    source nodes of the block produced just before it.

    Parameters
    ----------
    g
        Graph to sample from; it must provide ``sample_labors`` and
        ``canonical_etypes``.
    seed_nodes
        Nodes the first sampled block is built around.
    exclude_eids : optional
        Forwarded unchanged to ``g.sample_labors`` as ``exclude_edges``.

    Returns
    -------
    tuple
        ``(input_nodes, output_nodes, blocks)`` where ``input_nodes`` is
        ``srcdata[NID]`` of the last block sampled, ``output_nodes`` is the
        ``seed_nodes`` argument as passed in, and ``blocks`` lists the
        sampled blocks with the last sampled block first and the block
        built around the original seeds last.
    """
    output_nodes = seed_nodes

    # ``self.cnt`` is not touched inside the loop, so this value is the same
    # for every layer and is computed once up front.
    if self.cnt[1] <= 1:
        seed2_contr = 0
    else:
        seed2_contr = ((self.cnt[0] % self.cnt[1]) / self.cnt[1]).item()

    has_multiple_etypes = len(g.canonical_etypes) > 1

    blocks = []
    for i, fanout in enumerate(reversed(self.fanouts)):
        # With layer_dependency set, every layer uses the same random seed;
        # otherwise the layer index is added to the base seed.
        seed_offset = 0 if self.layer_dependency else i
        random_seed_i = F.zerocopy_to_dgl_ndarray(
            self.random_seed + seed_offset
        )
        frontier, importances = g.sample_labors(
            seed_nodes,
            fanout,
            edge_dir=self.edge_dir,
            prob=self.prob,
            importance_sampling=self.importance_sampling,
            random_seed=random_seed_i,
            seed2_contribution=seed2_contr,
            output_device=self.output_device,
            exclude_edges=exclude_eids,
        )
        eid = frontier.edata[EID]
        block = to_block(
            frontier, seed_nodes, include_dst_in_src=True, src_nodes=None
        )
        block.edata[EID] = eid
        if has_multiple_etypes:
            for etype, importance in zip(g.canonical_etypes, importances):
                if importance.shape[0] == block.num_edges(etype):
                    # Write through the per-edge-type view. With several
                    # edge types, ``block.edata[key]`` returns a newly
                    # built dict, so assigning into that dict would not
                    # store the weights on the block.
                    block.edges[etype].data["edge_weights"] = importance
        elif importances[0].shape[0] == block.num_edges():
            block.edata["edge_weights"] = importances[0]
        # The next layer is sampled around this block's source nodes.
        seed_nodes = block.srcdata[NID]
        blocks.append(block)

    # Blocks were collected in sampling order; callers receive them with the
    # last sampled block first.
    blocks.reverse()

    self.set_seed()
    return seed_nodes, output_nodes, blocks