(ids)
| 6 | from argparse import ArgumentParser |
| 7 | |
def get_class_labels(
    ids,
    *,
    subjs=("CU", "DA", "DR", "NI", "GU", "IA"),
    path_template="/dfs/scratch0/scisurv/clean/{}.tsv",
):
    """Return the class label of each id, taken from per-subject TSV files.

    For every subject code in *subjs*, the file ``path_template.format(code)``
    is read. Its first line is skipped as a header. The first
    whitespace-separated field of every remaining non-blank line is parsed as
    an integer id, and that id is mapped to the code's position in *subjs*.
    If an id appears in several files, the later subject in *subjs* wins.

    Args:
        ids: Iterable of integer ids to label.
        subjs: Ordered subject codes; a code's index is its class label.
        path_template: ``str.format`` template with one ``{}`` slot that is
            filled with each subject code to locate its file.

    Returns:
        A list holding the class label of each id, in the order of *ids*.

    Raises:
        KeyError: If an id in *ids* is not listed in any subject file.
        OSError: If a subject file cannot be opened.
        ValueError: If a first field cannot be parsed as an integer.
    """
    class_map = {}
    for label, code in enumerate(subjs):
        with open(path_template.format(code)) as fp:
            fp.readline()  # skip the header row
            for line in fp:
                fields = line.split()
                if not fields:
                    # Blank/trailing lines previously raised IndexError.
                    continue
                class_map[int(fields[0])] = label
    return [class_map[i] for i in ids]
| 18 | |
| 19 | def run_regression(train_embeds, train_labels, test_embeds, test_labels): |
| 20 | np.random.seed(1) |