(args)
| 33 | return args |
| 34 | |
def main(args):
    """Run one CLIP experiment selected by ``args.mode``.

    Supported modes:
        ``'zs'`` -- zero-shot evaluation of the CLIP model on ``args.dataset``;
                    logs the results object and the accuracy.
        ``'fe'`` -- feature extraction on ``args.dataset``; the features are
                    written with ``np.savetxt`` to
                    ``feat/fe_<model_name>_<dataset_name>.csv``.
        ``'ft'`` -- fine-tuning on ``args.dataset`` with Adam, evaluating on
                    ``args.test_data``; logs the value returned by
                    ``clip.finetune`` as the accuracy.

    Args:
        args: parsed command-line namespace. Always reads ``mode``, ``model``,
            ``dataset``, ``root`` and ``batchsize``. Mode ``'ft'`` additionally
            reads ``test_data``, ``test_batchsize``, ``lr``, ``beta1``,
            ``beta2``, ``eps``, ``weight_decay`` and ``nepoch``, plus an
            optional ``model_dir`` that overrides where the fine-tuned
            checkpoint is written (default: ``<cwd>/model``).
            Side effect: ``args.log_file`` is set to the log path used.

    Raises:
        NotImplementedError: if ``args.mode`` is not ``'zs'``, ``'fe'`` or
            ``'ft'``. This is checked before any model or data is loaded.
    """
    # Fail fast: don't load CLIP weights and a whole dataset only to reject
    # an unsupported mode afterwards.
    if args.mode not in ('zs', 'fe', 'ft'):
        raise NotImplementedError('Unsupported mode: {!r}'.format(args.mode))

    model, dataset = args.model, args.dataset
    model_name = ClipModel.get_model_name_by_index(model)
    dataset_name = ImageTextData.get_data_name_by_index(dataset)
    # Shared stem for the log file, feature file and checkpoint of this run.
    run_tag = '{}_{}_{}'.format(args.mode, model_name, dataset_name)

    log_dir = os.path.join(os.getcwd(), 'log')
    os.makedirs(log_dir, exist_ok=True)
    args.log_file = os.path.join(log_dir, run_tag + '.txt')
    logger = get_logger(args.log_file, args.log_file)
    logger.info(args)

    clip = ClipModel(model, logger=logger)
    logger.info(f'Clip model {model_name} loaded')

    itdata = ImageTextData(dataset, root=args.root, preprocess=clip.preprocess)
    # Shuffle only when training. In 'fe' mode the saved feature rows must
    # stay in dataset order so they can be matched back to their samples.
    # In 'zs' mode the order does not change accuracy, and a fixed order
    # keeps runs reproducible.
    train_loader = torch.utils.data.DataLoader(
        itdata, batch_size=args.batchsize, shuffle=(args.mode == 'ft'))
    logger.info(f'Dataset {dataset_name} loaded')

    if args.mode == 'zs':  # zero-shot evaluation
        acc, res = clip.evaluate(train_loader)
        logger.info('Results: {}'.format(res))
        logger.info('Accuracy: {:.2f}%'.format(acc * 100))
    elif args.mode == 'fe':  # feature extraction
        res = clip.feature_extraction(train_loader)
        logger.info('Feature extracted!')
        os.makedirs('feat', exist_ok=True)
        feat_file = os.path.join('feat', run_tag + '.csv')
        # NOTE(review): np.savetxt's default delimiter is a single space, so
        # despite the .csv suffix this file is whitespace-separated. This is
        # kept unchanged for existing readers; consider delimiter=','.
        np.savetxt(feat_file, res, fmt='%.4f')
    else:  # 'ft': fine-tuning (the mode was validated above)
        test_data = ImageTextData(args.test_data, root=args.root, preprocess=clip.preprocess)
        test_loader = torch.utils.data.DataLoader(
            test_data, batch_size=args.test_batchsize, shuffle=False, drop_last=False)
        optimizer = optim.Adam(
            clip.model.parameters(),
            lr=args.lr,
            betas=(args.beta1, args.beta2),
            eps=args.eps,
            weight_decay=args.weight_decay,
        )
        # Checkpoints go under the working directory (like log/ and feat/)
        # rather than a machine-specific absolute path. args.model_dir, if
        # set, overrides this location.
        model_dir = getattr(args, 'model_dir', None) or os.path.join(os.getcwd(), 'model')
        os.makedirs(model_dir, exist_ok=True)
        save_path = os.path.join(model_dir, run_tag + '.pt')
        best_acc = clip.finetune(train_loader, test_loader, optimizer, args.nepoch, save_path=save_path)
        logger.info('Accuracy: {:.2f}%'.format(best_acc * 100))
| 69 | |
| 70 | |
| 71 | def sweep_index(model=-1, data=-1): |
no test coverage detected