MCPcopy
hub / github.com/THUDM/CogDL / Graph

Class Graph

cogdl/data/data.py:474–959  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

472
473
474class Graph(BaseGraph):
475 def __init__(self, x=None, y=None, **kwargs):
476 super(Graph, self).__init__()
477 if x is not None:
478 if not torch.is_tensor(x):
479 raise ValueError("Node features must be Tensor")
480 self.x = x
481 self.y = y
482 self.grb_adj = None
483 num_nodes = x.shape[0] if x is not None else None
484
485 for key, item in kwargs.items():
486 if key == "num_nodes":
487 self.__num_nodes__ = item
488 num_nodes = item
489 elif key == "grb_adj":
490 self.grb_adj = item
491 elif not is_read_adj_key(key):
492 self[key] = item
493
494 if "edge_index_train" in kwargs:
495 self._adj_train = Adjacency(num_nodes=num_nodes)
496 for key, item in kwargs.items():
497 if is_adj_key_train(key):
498 _key = re.search(r"(.*)_train", key).group(1)
499 if _key.startswith("edge_"):
500 _key = _key.split("edge_")[1]
501 if _key == "index":
502 self._adj_train.edge_index = item
503 else:
504 self._adj_train[_key] = item
505 else:
506 self._adj_train = None
507
508 self._adj_full = Adjacency(num_nodes=num_nodes)
509 for key, item in kwargs.items():
510 if is_read_adj_key(key) and not is_adj_key_train(key):
511 if key.startswith("edge_"):
512 key = key.split("edge_")[-1]
513 if key == "index":
514 self._adj_full.edge_index = item
515 else:
516 self._adj_full[key] = item
517
518 self._adj = self._adj_full
519 self.__is_train__ = False
520 self.__temp_adj_stack__ = list()
521 self.__temp_storage__ = dict()
522
523 def train(self):
524 self.__is_train__ = True
525 if self._adj_train is not None:
526 self._adj = self._adj_train
527 return self
528
529 def eval(self):
530 self._adj = self._adj_full
531 self.__is_train__ = False

Callers 15

test_base_layer (Function) · 0.90
test_gine_layer (Function) · 0.90
build_toy_data (Function) · 0.90
process (Method) · 0.90
process (Method) · 0.90
process (Method) · 0.90
process (Method) · 0.90
1 graph.py (File) · 0.90
1 graph_cn.py (File) · 0.90

Calls

no outgoing calls