MCPcopy
hub / github.com/jindongwang/transferlearning / main

Function main

code/clip/main.py:35–68  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

33 return args
34
def main(args):
    """Run one CLIP experiment: zero-shot eval, feature extraction, or fine-tuning.

    ``args`` is the parsed command-line namespace. Fields read here:
    ``model`` / ``dataset`` (indices resolved to names via
    ``ClipModel.get_model_name_by_index`` / ``ImageTextData.get_data_name_by_index``),
    ``mode`` (``'zs'``, ``'fe'`` or ``'ft'``), ``root`` and ``batchsize``.
    In ``'ft'`` mode it also reads ``test_data``, ``test_batchsize``, ``lr``,
    ``beta1``, ``beta2``, ``eps``, ``weight_decay`` and ``nepoch``. An optional
    ``save_dir`` attribute selects where the fine-tuned checkpoint is written
    (default: ``<cwd>/model``).

    Side effects: sets ``args.log_file`` and logs to ``<cwd>/log/<tag>.txt``.
    In ``'fe'`` mode, features are written to ``feat/<tag>.csv``. In ``'ft'``
    mode, ``<save_dir>/<tag>.pt`` is passed to ``clip.finetune`` as its save
    path. Here ``<tag>`` is ``<mode>_<model_name>_<dataset_name>``.

    Raises:
        NotImplementedError: if ``args.mode`` is not a supported mode.
    """
    supported_modes = ('zs', 'fe', 'ft')
    if args.mode not in supported_modes:
        # Reject bad modes before paying for model/dataset loading.
        raise NotImplementedError(
            'Unsupported mode {!r}; expected one of {}'.format(args.mode, supported_modes))

    model, dataset = args.model, args.dataset
    model_name = ClipModel.get_model_name_by_index(model)
    dataset_name = ImageTextData.get_data_name_by_index(dataset)
    # Shared stem for the log, feature and checkpoint file names.
    run_tag = '{}_{}_{}'.format(args.mode, model_name, dataset_name)

    log_dir = os.path.join(os.getcwd(), 'log')
    # NOTE(review): unclear whether get_logger creates the directory; make sure it exists.
    os.makedirs(log_dir, exist_ok=True)
    args.log_file = os.path.join(log_dir, run_tag + '.txt')
    logger = get_logger(args.log_file, args.log_file)
    logger.info(args)

    clip = ClipModel(model, logger=logger)
    logger.info(f'Clip model {model_name} loaded')

    itdata = ImageTextData(dataset, root=args.root, preprocess=clip.preprocess)
    # Shuffle only when training. For 'zs' / 'fe' a fixed order makes runs
    # reproducible. It also keeps feature-file rows in dataset order, assuming
    # clip.feature_extraction emits rows in loader order (TODO confirm).
    train_loader = torch.utils.data.DataLoader(
        itdata, batch_size=args.batchsize, shuffle=(args.mode == 'ft'))
    logger.info(f'Dataset {dataset_name} loaded')

    if args.mode == 'zs':  # zero-shot evaluation
        acc, res = clip.evaluate(train_loader)
        logger.info('Results: {}'.format(res))
        logger.info('Accuracy: {:.2f}%'.format(acc * 100))
    elif args.mode == 'fe':  # feature extraction
        res = clip.feature_extraction(train_loader)
        logger.info('Feature extracted!')
        os.makedirs('feat', exist_ok=True)
        feat_file = os.path.join('feat', run_tag + '.csv')
        # NOTE: np.savetxt's default delimiter is a single space, so despite
        # the .csv extension this file is whitespace-separated.
        np.savetxt(feat_file, res, fmt='%.4f')
    else:  # 'ft': fine-tuning (mode already validated above)
        test_data = ImageTextData(args.test_data, root=args.root, preprocess=clip.preprocess)
        test_loader = torch.utils.data.DataLoader(
            test_data, batch_size=args.test_batchsize, shuffle=False, drop_last=False)
        optimizer = optim.Adam(
            clip.model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2),
            eps=args.eps, weight_decay=args.weight_decay)
        # The checkpoint directory used to be hard-coded to one developer's
        # home directory. It is now configurable, defaults to <cwd>/model, and
        # is created up front so a long training run cannot fail at save time.
        save_dir = getattr(args, 'save_dir', None) or os.path.join(os.getcwd(), 'model')
        os.makedirs(save_dir, exist_ok=True)
        best_acc = clip.finetune(
            train_loader, test_loader, optimizer, args.nepoch,
            save_path=os.path.join(save_dir, run_tag + '.pt'))
        logger.info('Accuracy: {:.2f}%'.format(best_acc * 100))
69
70
71def sweep_index(model=-1, data=-1):

Callers 1

sweep (Function) · 0.70

Calls 9

evaluate (Method) · 0.95
feature_extraction (Method) · 0.95
finetune (Method) · 0.95
get_logger (Function) · 0.90
ClipModel (Class) · 0.90
ImageTextData (Class) · 0.90
parameters (Method) · 0.45

Tested by

no test coverage detected