(opt)
| 2 | |
| 3 | |
def build_linear(opt):
    """Create the linear classification head for a backbone.

    The input width is 2048 scaled by the width multiplier encoded as an
    ``'x4'`` or ``'x2'`` suffix of ``opt.arch``; any other ``arch`` value
    gets 2048 input features.

    Args:
        opt: options object exposing ``n_class`` (number of output units)
            and ``arch`` (architecture name string).

    Returns:
        An ``nn.Linear`` mapping the feature width to ``opt.n_class`` outputs.
    """
    num_outputs = opt.n_class
    arch_name = opt.arch

    # Suffixes are checked in insertion order ('x4' before 'x2'); the first
    # match wins, otherwise the multiplier stays at 1.
    suffix_multipliers = {'x4': 4, 'x2': 2}
    multiplier = next(
        (factor for suffix, factor in suffix_multipliers.items()
         if arch_name.endswith(suffix)),
        1,
    )

    return nn.Linear(2048 * multiplier, num_outputs)