| 94 | |
| 95 | |
def img_param_init(args):
    """Populate dataset-specific image parameters on ``args``.

    Reads ``args.dataset`` and sets, in place:

    * ``args.domains`` -- list of domain names for the selected dataset;
    * ``args.img_dataset`` -- mapping of every supported dataset name to its
      list of domain names;
    * ``args.input_shape`` -- ``(3, 32, 32)`` for ``'dg5'``, otherwise
      ``(3, 224, 224)``;
    * ``args.num_classes`` -- number of target classes for the dataset.

    Args:
        args: Namespace-like object with a ``dataset`` attribute. It is
            mutated in place.

    Returns:
        The same ``args`` object.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names.
            ``args`` is left unmodified in that case.
    """
    # Single source of truth for the domain lists. It is built fresh on each
    # call so callers cannot accidentally share/mutate lists across calls.
    img_dataset = {
        'office': ['amazon', 'dslr', 'webcam'],
        'office-caltech': ['amazon', 'dslr', 'webcam', 'caltech'],
        'office-home': ['Art', 'Clipart', 'Product', 'Real_World'],
        'PACS': ['art_painting', 'cartoon', 'photo', 'sketch'],
        'dg5': ['mnist', 'mnist_m', 'svhn', 'syn', 'usps'],
        'VLCS': ['Caltech101', 'LabelMe', 'SUN09', 'VOC2007'],
    }
    num_classes = {
        'office': 31,
        # Office-Caltech10 has 10 classes. Previously num_classes was left
        # unset for this dataset.
        'office-caltech': 10,
        'office-home': 65,
        'PACS': 7,
        'dg5': 10,
        'VLCS': 5,
    }

    dataset = args.dataset
    if dataset not in img_dataset:
        # Previously this only printed a message and then failed with an
        # UnboundLocalError on `domains`; fail fast with a clear error instead.
        raise ValueError(
            "No such dataset exists! Got {!r}; expected one of {}.".format(
                dataset, list(img_dataset)
            )
        )

    # Copy so args.domains is independent of args.img_dataset[dataset],
    # matching the original, which built two separate lists.
    args.domains = list(img_dataset[dataset])
    args.img_dataset = img_dataset
    # The dg5 domains (mnist, svhn, usps, ...) use 32x32 inputs; all other
    # supported datasets use 224x224.
    args.input_shape = (3, 32, 32) if dataset == 'dg5' else (3, 224, 224)
    args.num_classes = num_classes[dataset]
    return args