(
model,
graph, feats, labels,
split_idx,
lr_f, weight_decay_f, max_epoch_f,
linear_prob=True,
device=0,
batch_size=256,
logger=None, ego_graph_nodes=None,
label_rate=1.0,
full_graph_forward=False,
shuffle=True,
)
| 27 | |
| 28 | |
def evaluate(
    model,
    graph, feats, labels,
    split_idx,
    lr_f, weight_decay_f, max_epoch_f,
    linear_prob=True,
    device=0,
    batch_size=256,
    logger=None, ego_graph_nodes=None,
    label_rate=1.0,
    full_graph_forward=False,
    shuffle=True,
):
    """Evaluate ``model`` on train/valid/test ego-graph ("lc") samples.

    Only the split *sizes* are read from ``split_idx``. The first ``num_train``
    entries of ``ego_graph_nodes`` / ``labels`` form the training set, the next
    ``num_valid`` the validation set and the following ``num_test`` the test set.
    NOTE(review): this assumes the caller has already ordered ``ego_graph_nodes``
    and ``labels`` as train -> valid -> test; confirm against the caller.

    Args:
        model, graph, feats: Forwarded unchanged to the selected routine.
        labels: Must support ``len()`` and indexing with an integer NumPy
            array (e.g. an ndarray or tensor).
        split_idx: Mapping with keys ``"train"``, ``"valid"`` and ``"test"``;
            each value must expose ``.shape[0]``.
        lr_f, weight_decay_f, max_epoch_f: Forwarded to the selected routine
            (``max_epoch_f`` is halved in fine-tuning mode).
        linear_prob: If True, call ``linear_probing_minibatch``; otherwise
            call ``finetune``.
        device, batch_size: Forwarded to the selected routine.
        logger: Forwarded to ``finetune`` only.
        ego_graph_nodes: Per-sample ego-graph entries. Required despite the
            ``None`` default, which is kept for call compatibility.
        label_rate: Fraction of training samples to keep (random subset).
            Only applied in fine-tuning mode.
        full_graph_forward: Forwarded to ``finetune`` only.
        shuffle: Forwarded to ``linear_probing_minibatch`` only.

    Returns:
        Whatever ``linear_probing_minibatch`` or ``finetune`` returns.

    Raises:
        ValueError: If ``ego_graph_nodes`` is None, or if ``ego_graph_nodes``
            or ``labels`` has fewer entries than the three splits combined.
    """
    logging.info("Using `lc` for evaluation...")
    num_train, num_val, num_test = [split_idx[k].shape[0] for k in ["train", "valid", "test"]]
    num_total = num_train + num_val + num_test

    # Fail early with a clear message instead of a TypeError/IndexError
    # from the indexing below.
    if ego_graph_nodes is None:
        raise ValueError("`ego_graph_nodes` is required for `lc` evaluation")
    if len(ego_graph_nodes) < num_total or len(labels) < num_total:
        raise ValueError(
            f"need at least {num_total} ego graphs and labels "
            f"(train={num_train}, valid={num_val}, test={num_test}), "
            f"got {len(ego_graph_nodes)} ego graphs and {len(labels)} labels"
        )

    # The splits are contiguous, consecutive slices: [train | valid | test].
    train_g_idx = np.arange(0, num_train)
    val_g_idx = np.arange(num_train, num_train + num_val)
    test_g_idx = np.arange(num_train + num_val, num_total)

    train_ego_graph_nodes = [ego_graph_nodes[i] for i in train_g_idx]
    val_ego_graph_nodes = [ego_graph_nodes[i] for i in val_g_idx]
    test_ego_graph_nodes = [ego_graph_nodes[i] for i in test_g_idx]

    train_lbls, val_lbls, test_lbls = labels[train_g_idx], labels[val_g_idx], labels[test_g_idx]

    logging.info(f"num_train: {num_train}, num_val: {num_val}, num_test: {num_test}")
    logging.info(f"-- train_ego_nodes:{len(train_ego_graph_nodes)}, val_ego_nodes:{len(val_ego_graph_nodes)}, test_ego_nodes:{len(test_ego_graph_nodes)} ---")

    if linear_prob:
        if label_rate < 1.0:
            # label_rate subsampling is only implemented for fine-tuning; make that visible.
            logging.warning("`label_rate` is ignored when `linear_prob=True`; using all training labels")
        result = linear_probing_minibatch(
            model, graph, feats,
            [train_ego_graph_nodes, val_ego_graph_nodes, test_ego_graph_nodes],
            [train_lbls, val_lbls, test_lbls],
            lr_f=lr_f, weight_decay_f=weight_decay_f, max_epoch_f=max_epoch_f,
            batch_size=batch_size, device=device, shuffle=shuffle,
        )
    else:
        # Fine-tuning runs for half of the given epoch budget.
        max_epoch_f = max_epoch_f // 2

        if label_rate < 1.0:
            # Keep a random subset of the training samples. This draws from
            # NumPy's global RNG exactly like arange + shuffle did, so seeded
            # runs keep the same subset.
            num_keep = int(label_rate * len(train_ego_graph_nodes))
            keep_idx = np.random.permutation(len(train_ego_graph_nodes))[:num_keep]
            train_ego_graph_nodes = [train_ego_graph_nodes[i] for i in keep_idx]
            train_lbls = train_lbls[keep_idx]

        logging.info(f"-- train_ego_nodes:{len(train_ego_graph_nodes)}, val_ego_nodes:{len(val_ego_graph_nodes)}, test_ego_nodes:{len(test_ego_graph_nodes)} ---")

        result = finetune(
            model, graph, feats,
            [train_ego_graph_nodes, val_ego_graph_nodes, test_ego_graph_nodes],
            [train_lbls, val_lbls, test_lbls],
            split_idx=split_idx,
            lr_f=lr_f, weight_decay_f=weight_decay_f, max_epoch_f=max_epoch_f,
            use_scheduler=True, batch_size=batch_size, device=device,
            logger=logger, full_graph_forward=full_graph_forward,
        )

    return result
no test coverage detected