MCPcopy
hub / github.com/dmlc/dgl / process

Method process

python/dgl/data/actor.py:70–118  ·  view source on GitHub ↗

Load and process the data.

(self)

Source from the content-addressed store, hash-verified

68 )
69
def process(self):
    """Load the raw files and build the graph with features, labels and masks.

    Reads ``out1_node_feature_label.txt`` and ``out1_graph_edges.txt`` from
    ``self.raw_path`` (the first line of each file is skipped), then the ten
    ``{self.name}_split_0.6_0.2_{i}.npz`` split files.

    Sets ``self._num_classes`` to the largest label plus one and ``self._g``
    to a graph whose ``ndata`` holds ``feat``, ``label``, ``train_mask``,
    ``val_mask`` and ``test_mask``. Each mask has one column per split.

    Raises
    ------
    ModuleNotFoundError
        If PyTorch cannot be imported.
    """
    try:
        import torch
    except ImportError as err:
        raise ModuleNotFoundError(
            "This dataset requires PyTorch to be the backend."
        ) from err

    def _data_lines(path):
        # Skip the first line and drop blank lines. Slicing with [1:-1]
        # would silently lose the last record when the file does not end
        # with a newline.
        with open(path, "r") as f:
            return [ln for ln in f.read().split("\n")[1:] if ln.strip()]

    # Process node features and labels.
    data = [
        x.split("\t")
        for x in _data_lines(f"{self.raw_path}/out1_node_feature_label.txt")
    ]

    rows, cols = [], []
    labels = torch.empty(len(data), dtype=torch.long)
    for n_id, col, label in data:
        # Each record lists the active feature columns of one node.
        col = [int(x) for x in col.split(",")]
        rows += [int(n_id)] * len(col)
        cols += col

        labels[int(n_id)] = int(label)

    row, col = torch.tensor(rows), torch.tensor(cols)
    # Dense 0/1 feature matrix: one row per record, width set by the
    # largest feature index seen.
    features = torch.zeros(len(data), int(col.max()) + 1)
    features[row, col] = 1.0

    self._num_classes = int(labels.max().item()) + 1

    # Process graph structure.
    data = [
        [int(v) for v in r.split("\t")]
        for r in _data_lines(f"{self.raw_path}/out1_graph_edges.txt")
    ]
    # The first file column is bound to ``dst`` and the second to ``src``.
    dst, src = torch.tensor(data, dtype=torch.long).t().contiguous()

    self._g = graph((src, dst), num_nodes=features.size(0))
    self._g.ndata["feat"] = features
    self._g.ndata["label"] = labels

    # Process 10 train/val/test node splits.
    train_masks, val_masks, test_masks = [], [], []
    for i in range(10):
        filepath = f"{self.raw_path}/{self.name}_split_0.6_0.2_{i}.npz"
        # The loaded .npz archive keeps a file handle open; the context
        # manager closes it once the arrays have been read into memory.
        with np.load(filepath) as split:
            train_masks.append(torch.from_numpy(split["train_mask"]))
            val_masks.append(torch.from_numpy(split["val_mask"]))
            test_masks.append(torch.from_numpy(split["test_mask"]))
    self._g.ndata["train_mask"] = torch.stack(train_masks, dim=1).bool()
    self._g.ndata["val_mask"] = torch.stack(val_masks, dim=1).bool()
    self._g.ndata["test_mask"] = torch.stack(test_masks, dim=1).bool()
119
def has_cache(self):
    """Return whether the path stored in ``self.raw_path`` exists."""
    raw_location = self.raw_path
    return os.path.exists(raw_location)

Callers 1

loadMethod · 0.95

Calls 5

graphFunction · 0.85
tMethod · 0.80
readMethod · 0.45
sizeMethod · 0.45
loadMethod · 0.45

Tested by

no test coverage detected