MCPcopy
hub / github.com/dmlc/dgl / process

Method process

python/dgl/data/reddit.py:108–141  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

106 )
107
def process(self):
    """Load the raw Reddit files and build the processed DGL graph.

    Reads the adjacency matrix from
    ``reddit{self._self_loop_str}_graph.npz`` and the node features,
    labels and split assignments from ``reddit_data.npz``, both located
    under ``self.raw_path``. The result is stored in ``self._graph``.

    Node data attached to the graph:
        train_mask / val_mask / test_mask: built from ``node_types`` equal
            to 1 / 2 / 3 respectively (via ``generate_mask_tensor``).
        feat: node features converted to a float32 tensor.
        label: node labels converted to an int64 tensor.

    The graph is then passed through ``reorder_graph`` with
    ``node_permute_algo="rcmk"``, ``edge_permute_algo="dst"`` and
    ``store_ids=False``, and finally ``self._print_info()`` is called.
    """
    # Graph structure.
    coo_adj = sp.load_npz(
        os.path.join(
            self.raw_path, "reddit{}_graph.npz".format(self._self_loop_str)
        )
    )
    self._graph = from_scipy(coo_adj)

    # Features, labels and train/val/test split assignments.
    # ``np.load`` on an .npz archive returns an NpzFile that keeps the
    # underlying file open until closed; use it as a context manager so the
    # handle is released as soon as the arrays have been read into memory.
    data_path = os.path.join(self.raw_path, "reddit_data.npz")
    with np.load(data_path) as reddit_data:
        features = reddit_data["feature"]
        labels = reddit_data["label"]
        node_types = reddit_data["node_types"]

    # Train/val/test indices are encoded as node_types 1/2/3.
    train_mask = node_types == 1
    val_mask = node_types == 2
    test_mask = node_types == 3
    self._graph.ndata["train_mask"] = generate_mask_tensor(train_mask)
    self._graph.ndata["val_mask"] = generate_mask_tensor(val_mask)
    self._graph.ndata["test_mask"] = generate_mask_tensor(test_mask)
    self._graph.ndata["feat"] = F.tensor(
        features, dtype=F.data_type_dict["float32"]
    )
    self._graph.ndata["label"] = F.tensor(
        labels, dtype=F.data_type_dict["int64"]
    )
    # NOTE(review): store_ids=False presumably means the pre-reordering
    # node/edge IDs are not kept on the graph — confirm against reorder_graph.
    self._graph = reorder_graph(
        self._graph,
        node_permute_algo="rcmk",
        edge_permute_algo="dst",
        store_ids=False,
    )

    self._print_info()
142
143 def has_cache(self):
144 graph_path = os.path.join(self.save_path, "dgl_graph.bin")

Callers

nothing calls this directly

Calls 7

_print_info — Method · 0.95
from_scipy — Function · 0.85
generate_mask_tensor — Function · 0.85
reorder_graph — Function · 0.85
format — Method · 0.80
join — Method · 0.45
load — Method · 0.45

Tested by

no test coverage detected