(args)
| 27 | |
| 28 | |
def train_valid_target_eval_names(args):
    """Group consecutive integer indices into 'train', 'valid' and 'target'.

    Indices are handed out in two passes over the domains:

    1. One index per domain that is NOT in ``args.test_envs``; these go to
       ``'train'`` and are numbered 0, 1, 2, ...
    2. One index per domain (every domain, in domain order), continuing from
       where pass 1 stopped. The index goes to ``'target'`` when the domain is
       in ``args.test_envs`` and to ``'valid'`` otherwise.

    NOTE(review): the indices presumably line up with a list of eval loaders
    built in the same order elsewhere. Confirm against the caller.

    Args:
        args: object exposing ``domain_num`` (an int used as the domain count)
            and ``test_envs`` (a container supporting ``in``).

    Returns:
        dict with keys ``'train'``, ``'valid'`` and ``'target'``. Each value
        is a list of int indices.
    """
    source_domains = [d for d in range(args.domain_num) if d not in args.test_envs]
    n_train = len(source_domains)

    names = {'train': list(range(n_train)), 'valid': [], 'target': []}
    for domain in range(args.domain_num):
        # Second-pass indices are offset by the number of training indices.
        split = 'target' if domain in args.test_envs else 'valid'
        names[split].append(n_train + domain)
    return names
| 43 | |
| 44 | |
| 45 | def alg_loss_dict(args): |