MCPcopy
hub / github.com/jindongwang/transferlearning / get_params

Function get_params

code/DeepDG/alg/opt.py:5–46  ·  view source on GitHub ↗
(alg, args, inner=False, alias=True, isteacher=False)

Source from the content-addressed store, hash-verified

3
4
def get_params(alg, args, inner=False, alias=True, isteacher=False):
    """Build parameter groups, each with its own learning rate.

    Every group is a dict with ``'params'`` (the result of a sub-module's
    ``.parameters()``) and ``'lr'`` keys. The first ("backbone") group uses
    ``args.lr_decay1 * initlr``; every other group uses
    ``args.lr_decay2 * initlr``.

    Args:
        alg: The model.
            - When ``isteacher`` is true, it is indexed as ``alg[0..2]``.
            - When ``inner`` is true or ``alias`` is false, it is indexed
              as ``alg[0]`` and ``alg[1]``.
            - Otherwise its ``featurizer`` and ``classifier`` attributes
              are used.
            - ``discriminator`` / ``class_embeddings`` attributes are also
              read for the DANN / CDANN algorithms.
        args: Config namespace. Reads ``schuse``, ``schusech``, ``lr``,
            ``inner_lr`` (only when ``inner`` is true and ``schuse`` is
            false), ``lr_decay1``, ``lr_decay2``, and ``algorithm`` (not
            read when ``isteacher`` is true).
        inner: Use ``args.inner_lr`` as the base rate (when no scheduler is
            in use) and index ``alg`` positionally.
        alias: When true and ``inner`` is false, access sub-modules by
            attribute name; otherwise access them by index.
        isteacher: Return exactly three groups for ``alg[0]``, ``alg[1]``
            and ``alg[2]``, without any algorithm-specific extras.

    Returns:
        A list of ``{'params': ..., 'lr': ...}`` dicts. Presumably these
        are handed to a ``torch.optim`` optimizer by ``get_optimizer``.
    """
    if args.schuse:
        # With a scheduler, 'cos' starts from args.lr. Any other schedule
        # starts from 1.0, presumably because the scheduler supplies the
        # actual rate (TODO confirm). inner_lr is not consulted here.
        initlr = args.lr if args.schusech == 'cos' else 1.0
    else:
        initlr = args.inner_lr if inner else args.lr

    lr_backbone = args.lr_decay1 * initlr
    lr_head = args.lr_decay2 * initlr

    if isteacher:
        return [
            {'params': alg[0].parameters(), 'lr': lr_backbone},
            {'params': alg[1].parameters(), 'lr': lr_head},
            {'params': alg[2].parameters(), 'lr': lr_head},
        ]

    # The inner-loop and non-alias cases used to be two identical branches;
    # both address the sub-modules positionally.
    if inner or not alias:
        params = [
            {'params': alg[0].parameters(), 'lr': lr_backbone},
            {'params': alg[1].parameters(), 'lr': lr_head},
        ]
    else:
        params = [
            {'params': alg.featurizer.parameters(), 'lr': lr_backbone},
            {'params': alg.classifier.parameters(), 'lr': lr_head},
        ]

    # Substring match: 'CDANN' contains 'DANN', so this one test covers both
    # algorithms. This assumes args.algorithm is the algorithm name string
    # (confirm against the argument parser).
    if 'DANN' in args.algorithm:
        params.append({'params': alg.discriminator.parameters(),
                       'lr': lr_head})
    if 'CDANN' in args.algorithm:
        params.append({'params': alg.class_embeddings.parameters(),
                       'lr': lr_head})
    return params
47
48
49def get_optimizer(alg, args, inner=False, alias=True, isteacher=False):

Callers 1

get_optimizer (Function) · 0.70

Calls 1

parameters (Method) · 0.45

Tested by

no test coverage detected