This function creates a sample dataset based on the MAG240 dataset. Parameters: ----------- root_dir : string - the directory in which all the files for the chunked dataset will be stored.
(
root_dir,
num_chunks,
data_fmt="numpy",
edges_fmt="csv",
vector_rows=False,
**kwargs,
)
| 340 | |
| 341 | |
| 342 | def create_chunked_dataset( |
| 343 | root_dir, |
| 344 | num_chunks, |
| 345 | data_fmt="numpy", |
| 346 | edges_fmt="csv", |
| 347 | vector_rows=False, |
| 348 | **kwargs, |
| 349 | ): |
| 350 | """ |
| 351 | This function creates a sample dataset, based on MAG240 dataset. |
| 352 | |
| 353 | Parameters: |
| 354 | ----------- |
| 355 | root_dir : string |
| 356 | directory in which all the files for the chunked dataset will be stored. |
| 357 | """ |
| 358 | # Step0: prepare chunked graph data format. |
| 359 | # A synthetic mini MAG240. |
| 360 | num_institutions = 1200 |
| 361 | num_authors = 1200 |
| 362 | num_papers = 1200 |
| 363 | |
| 364 | def rand_edges(num_src, num_dst, num_edges): |
| 365 | eids = np.random.choice(num_src * num_dst, num_edges, replace=False) |
| 366 | src = torch.from_numpy(eids // num_dst) |
| 367 | dst = torch.from_numpy(eids % num_dst) |
| 368 | |
| 369 | return src, dst |
| 370 | |
| 371 | num_cite_edges = 24 * 1000 |
| 372 | num_write_edges = 12 * 1000 |
| 373 | num_affiliate_edges = 2400 |
| 374 | |
| 375 | # Structure. |
| 376 | data_dict = { |
| 377 | ("paper", "cites", "paper"): rand_edges( |
| 378 | num_papers, num_papers, num_cite_edges |
| 379 | ), |
| 380 | ("author", "writes", "paper"): rand_edges( |
| 381 | num_authors, num_papers, num_write_edges |
| 382 | ), |
| 383 | ("author", "affiliated_with", "institution"): rand_edges( |
| 384 | num_authors, num_institutions, num_affiliate_edges |
| 385 | ), |
| 386 | ("institution", "writes", "paper"): rand_edges( |
| 387 | num_institutions, num_papers, num_write_edges |
| 388 | ), |
| 389 | } |
| 390 | src, dst = data_dict[("author", "writes", "paper")] |
| 391 | data_dict[("paper", "rev_writes", "author")] = (dst, src) |
| 392 | g = dgl.heterograph(data_dict) |
| 393 | |
| 394 | # paper feat, label, year |
| 395 | num_paper_feats = 3 |
| 396 | paper_feat = np.random.randn(num_papers, num_paper_feats) |
| 397 | num_classes = 4 |
| 398 | paper_label = np.random.choice(num_classes, num_papers) |
| 399 | paper_year = np.random.choice(2022, num_papers) |
no test coverage detected