MCPcopy
hub / github.com/THUDM/CogDL / evaluate

Function evaluate

examples/graphmae2/main_large.py:29–87  ·  view source on GitHub ↗
(
    model, 
    graph, feats, labels, 
    split_idx, 
    lr_f, weight_decay_f, max_epoch_f, 
    linear_prob=True, 
    device=0, 
    batch_size=256, 
    logger=None, ego_graph_nodes=None, 
    label_rate=1.0,
    full_graph_forward=False,
    shuffle=True,
)

Source from the content-addressed store, hash-verified

27
28
def evaluate(
    model,
    graph, feats, labels,
    split_idx,
    lr_f, weight_decay_f, max_epoch_f,
    linear_prob=True,
    device=0,
    batch_size=256,
    logger=None, ego_graph_nodes=None,
    label_rate=1.0,
    full_graph_forward=False,
    shuffle=True,
):
    """Evaluate a pretrained model by linear probing or by fine-tuning.

    Split sizes are read from ``split_idx``. ``ego_graph_nodes`` and
    ``labels`` are then sliced *positionally*:
      - the first ``num_train`` entries are the training split;
      - the next ``num_val`` are the validation split;
      - the following ``num_test`` are the test split.
    ``split_idx`` is used here only for those sizes; it is also forwarded to
    ``finetune``.

    Args:
        model: Model under evaluation; forwarded to the evaluation helper.
        graph: Graph object; forwarded to the evaluation helper.
        feats: Node features; forwarded to the evaluation helper.
        labels: Labels ordered train, then valid, then test. Must support
            indexing with an integer NumPy array.
        split_idx: Mapping with ``"train"``, ``"valid"`` and ``"test"`` keys.
            Each value must expose ``.shape[0]`` (the split size).
        lr_f: Learning rate for the probe / fine-tuning stage.
        weight_decay_f: Weight decay for the probe / fine-tuning stage.
        max_epoch_f: Epoch budget. In fine-tuning mode it is halved
            (integer division) before use.
        linear_prob: If True, run ``linear_probing_minibatch``; otherwise
            run ``finetune``.
        device: Forwarded to the evaluation helper.
        batch_size: Forwarded to the evaluation helper.
        logger: Forwarded to ``finetune`` only.
        ego_graph_nodes: Indexable sequence of ego-graph node sets, ordered
            like ``labels``. Required despite the ``None`` default.
        label_rate: Fraction of training samples kept (random subset) when
            fine-tuning. Ignored by the linear-probing path.
        full_graph_forward: Forwarded to ``finetune`` only.
        shuffle: Forwarded to ``linear_probing_minibatch`` only.

    Returns:
        Whatever the selected helper (``linear_probing_minibatch`` or
        ``finetune``) returns.

    Raises:
        ValueError: If ``ego_graph_nodes`` is None.
    """
    logging.info("Using `lc` for evaluation...")

    # Both evaluation paths index into the ego graphs. Fail with a clear
    # message instead of a TypeError from subscripting None.
    if ego_graph_nodes is None:
        raise ValueError("evaluate() requires `ego_graph_nodes`; got None")

    num_train, num_val, num_test = [split_idx[k].shape[0] for k in ["train", "valid", "test"]]

    # Positional layout: [0, num_train) -> train, next num_val -> valid,
    # next num_test -> test.
    train_g_idx = np.arange(0, num_train)
    val_g_idx = np.arange(num_train, num_train + num_val)
    test_g_idx = np.arange(num_train + num_val, num_train + num_val + num_test)

    train_ego_graph_nodes = [ego_graph_nodes[i] for i in train_g_idx]
    val_ego_graph_nodes = [ego_graph_nodes[i] for i in val_g_idx]
    test_ego_graph_nodes = [ego_graph_nodes[i] for i in test_g_idx]

    train_lbls, val_lbls, test_lbls = labels[train_g_idx], labels[val_g_idx], labels[test_g_idx]

    # Internal invariant: each ego-graph list lines up with its label slice.
    assert len(train_ego_graph_nodes) == len(train_lbls)
    assert len(val_ego_graph_nodes) == len(val_lbls)
    assert len(test_ego_graph_nodes) == len(test_lbls)

    print(f"num_train: {num_train}, num_val: {num_val}, num_test: {num_test}")
    logging.info(
        "-- train_ego_nodes:%d, val_ego_nodes:%d, test_ego_nodes:%d ---",
        len(train_ego_graph_nodes), len(val_ego_graph_nodes), len(test_ego_graph_nodes),
    )

    if linear_prob:
        result = linear_probing_minibatch(
            model, graph, feats,
            [train_ego_graph_nodes, val_ego_graph_nodes, test_ego_graph_nodes],
            [train_lbls, val_lbls, test_lbls],
            lr_f=lr_f, weight_decay_f=weight_decay_f, max_epoch_f=max_epoch_f,
            batch_size=batch_size, device=device, shuffle=shuffle,
        )
    else:
        # Fine-tuning runs for half of the configured epoch budget.
        max_epoch_f = max_epoch_f // 2

        if label_rate < 1.0:
            # Keep a random subset of the training samples. The shuffle uses
            # the global NumPy RNG.
            rand_idx = np.arange(len(train_ego_graph_nodes))
            np.random.shuffle(rand_idx)
            rand_idx = rand_idx[:int(label_rate * len(train_ego_graph_nodes))]
            train_ego_graph_nodes = [train_ego_graph_nodes[i] for i in rand_idx]
            train_lbls = train_lbls[rand_idx]

            logging.info(
                "-- train_ego_nodes:%d, val_ego_nodes:%d, test_ego_nodes:%d ---",
                len(train_ego_graph_nodes), len(val_ego_graph_nodes), len(test_ego_graph_nodes),
            )

        result = finetune(
            model, graph, feats,
            [train_ego_graph_nodes, val_ego_graph_nodes, test_ego_graph_nodes],
            [train_lbls, val_lbls, test_lbls],
            split_idx=split_idx,
            lr_f=lr_f, weight_decay_f=weight_decay_f, max_epoch_f=max_epoch_f,
            use_scheduler=True, batch_size=batch_size, device=device,
            logger=logger, full_graph_forward=full_graph_forward,
        )

    return result

Callers 3

main_large.py — File · 0.70
modification — Method · 0.50
sample_final_edges — Method · 0.50

Calls 3

linear_probing_minibatch — Function · 0.90
finetune — Function · 0.90
shuffle — Method · 0.45

Tested by

no test coverage detected