(args)
| 43 | |
| 44 | |
def alg_loss_dict(args):
    """Return the loss-term names associated with ``args.algorithm``.

    Args:
        args: Any object exposing an ``algorithm`` attribute whose value is
            one of the keys of the table below (e.g. ``'ERM'``, ``'CORAL'``).

    Returns:
        A list of loss-term names for that algorithm. The table is rebuilt
        on every call, so each call returns a fresh list that the caller may
        mutate without affecting later calls. (Presumably these names are
        used as keys for recording/printing losses — confirm at call sites.)

    Raises:
        KeyError: If ``args.algorithm`` is not in the table. The message
            names the offending value and lists the supported algorithms.
    """
    loss_dict = {'ANDMask': ['total'],
                 'CORAL': ['class', 'coral', 'total'],
                 'DANN': ['class', 'dis', 'total'],
                 'ERM': ['class'],
                 'Mixup': ['class'],
                 'MLDG': ['total'],
                 'MMD': ['class', 'mmd', 'total'],
                 'GroupDRO': ['group'],
                 'RSC': ['class'],
                 'VREx': ['loss', 'nll', 'penalty'],
                 'DIFEX': ['class', 'dist', 'exp', 'align', 'total'],
                 }
    algorithm = args.algorithm
    try:
        return loss_dict[algorithm]
    except KeyError:
        # Re-raise as KeyError (callers may already catch it) but with an
        # actionable message instead of just the bare missing key.
        supported = ', '.join(sorted(loss_dict))
        raise KeyError(
            f"Unknown algorithm {algorithm!r}; supported algorithms: {supported}"
        ) from None
| 59 | |
| 60 | |
| 61 | def print_args(args, print_list): |