| 472 | |
| 473 | |
| 474 | class Graph(BaseGraph): |
def __init__(self, x=None, y=None, **kwargs):
    """Build a graph from node features, labels and keyword attributes.

    Args:
        x: Node features, or ``None``. When given it must satisfy
            ``torch.is_tensor``; its first dimension (``x.shape[0]``) is the
            default node count passed to the adjacency containers.
        y: Labels, stored unchanged on ``self.y``.
        **kwargs: Extra attributes, routed by key:
            * ``num_nodes`` is stored on ``self.__num_nodes__`` and overrides
              the node count derived from ``x``.
            * ``grb_adj`` is stored on ``self.grb_adj``.
            * Keys accepted by ``is_read_adj_key`` go into adjacency
              containers. Those also accepted by ``is_adj_key_train`` go to
              the training adjacency, and the rest go to the full adjacency.
              A leading ``edge_`` is removed from the attribute name, and the
              name ``index`` is assigned to ``edge_index``.
            * Every other key is assigned through ``self[key] = item``.

    Raises:
        ValueError: If ``x`` is not ``None`` and is not a tensor.
    """

    def _adj_attr_name(key):
        # Remove only a leading "edge_" ("edge_weight" -> "weight"). Slicing,
        # rather than split("edge_"), keeps names intact when they contain
        # "edge_" a second time.
        prefix = "edge_"
        return key[len(prefix):] if key.startswith(prefix) else key

    super(Graph, self).__init__()
    if x is not None and not torch.is_tensor(x):
        raise ValueError("Node features must be Tensor")
    self.x = x
    self.y = y
    self.grb_adj = None
    num_nodes = x.shape[0] if x is not None else None

    # First pass: handle scalar/config attributes. Adjacency keys are
    # deferred to the containers built below.
    for key, item in kwargs.items():
        if key == "num_nodes":
            self.__num_nodes__ = item
            num_nodes = item
        elif key == "grb_adj":
            self.grb_adj = item
        elif not is_read_adj_key(key):
            self[key] = item

    # A training adjacency is built only when "edge_index_train" is present.
    # NOTE(review): other *_train adjacency keys are ignored without it —
    # confirm that is intended.
    if "edge_index_train" in kwargs:
        self._adj_train = Adjacency(num_nodes=num_nodes)
        for key, item in kwargs.items():
            if is_adj_key_train(key):
                # Greedy match: drop everything from the last "_train" on.
                name = _adj_attr_name(re.search(r"(.*)_train", key).group(1))
                if name == "index":
                    self._adj_train.edge_index = item
                else:
                    self._adj_train[name] = item
    else:
        self._adj_train = None

    self._adj_full = Adjacency(num_nodes=num_nodes)
    for key, item in kwargs.items():
        if is_read_adj_key(key) and not is_adj_key_train(key):
            name = _adj_attr_name(key)
            if name == "index":
                self._adj_full.edge_index = item
            else:
                self._adj_full[name] = item

    # Start in evaluation mode: the active adjacency is the full one.
    self._adj = self._adj_full
    self.__is_train__ = False
    self.__temp_adj_stack__ = []
    self.__temp_storage__ = {}
| 522 | |
def train(self):
    """Mark the graph as being in training mode and return it.

    If a training adjacency exists (``self._adj_train`` is not ``None``), it
    becomes the active adjacency ``self._adj``. Otherwise ``self._adj`` is
    left as it was. The training flag is set in either case.
    """
    self.__is_train__ = True
    train_adj = self._adj_train
    if train_adj is None:
        return self
    self._adj = train_adj
    return self
| 528 | |
| 529 | def eval(self): |
| 530 | self._adj = self._adj_full |
| 531 | self.__is_train__ = False |
no outgoing calls