Repository: github.com/google-deepmind/learning-to-learn

Function ensemble

problems.py:102–132

Ensemble of problems. Args: problems, a list of problems, each specified by a dict containing the keys 'name' and 'options'; weights, an optional list of weights, one per problem. Returns: a build function that returns the sum of the (weighted) losses. Raises: ValueError if weights has an incorrect length.

(problems, weights=None)
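Written out, the value returned by the build function is a weighted sum over the N entries of problems. Here L_i denotes the loss produced by the i-th problem's build function and w_i is weights[i], taken as 1 when weights is omitted. The symbols are introduced only for illustration and do not appear in the source.

```latex
\mathcal{L}_{\text{ensemble}} = \sum_{i=0}^{N-1} w_i \, \mathcal{L}_i,
\qquad N = \texttt{len(problems)},
\qquad w_i = \begin{cases}
  \texttt{weights[i]} & \text{if weights is given,} \\
  1 & \text{otherwise.}
\end{cases}
```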


import sys

import tensorflow as tf  # tf.variable_scope requires the TensorFlow 1.x API.


def ensemble(problems, weights=None):
  """Ensemble of problems.

  Args:
    problems: List of problems. Each problem is specified by a dict containing
        the keys 'name' and 'options'.
    weights: Optional list of weights for each problem.

  Returns:
    A build function that constructs every problem and returns the sum of
    their (weighted) losses.

  Raises:
    ValueError: If weights has an incorrect length.
  """
  if weights and len(weights) != len(problems):
    raise ValueError("len(weights) != len(problems)")

  # Resolve each problem by name to a builder defined in this module and call
  # it with its options; each builder returns a build function.
  build_fns = [getattr(sys.modules[__name__], p["name"])(**p["options"])
               for p in problems]

  def build():
    loss = 0
    for i, build_fn in enumerate(build_fns):
      # A separate variable scope per problem keeps their variables distinct.
      with tf.variable_scope("problem_{}".format(i)):
        loss_p = build_fn()
        if weights:
          loss_p *= weights[i]
        loss += loss_p
    return loss

  return build
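The dispatch-and-sum pattern in ensemble can be tried without TensorFlow. The following is a minimal pure-Python sketch, not the repository's code: constant_problem and square_problem are hypothetical builders, a plain dict REGISTRY stands in for the getattr(sys.modules[__name__], ...) lookup, and the per-problem variable scopes are omitted because no TensorFlow variables are created.

```python
from typing import Callable, Dict, List, Optional

BuildFn = Callable[[], float]


def constant_problem(value: float = 1.0) -> BuildFn:
  """Hypothetical builder: its build function returns a fixed loss."""
  def build() -> float:
    return value
  return build


def square_problem(x: float = 2.0) -> BuildFn:
  """Hypothetical builder: its build function returns x ** 2."""
  def build() -> float:
    return x * x
  return build


# Hypothetical stand-in for getattr(sys.modules[__name__], name).
REGISTRY: Dict[str, Callable[..., BuildFn]] = {
    "constant_problem": constant_problem,
    "square_problem": square_problem,
}


def ensemble_sketch(problems: List[dict],
                    weights: Optional[List[float]] = None) -> BuildFn:
  """Sums the (weighted) losses of several problems, mirroring ensemble()."""
  # Same truthiness check as ensemble(): an empty list counts as "no weights".
  if weights and len(weights) != len(problems):
    raise ValueError("len(weights) != len(problems)")
  # Specs are resolved eagerly; computing the losses is deferred to build().
  build_fns = [REGISTRY[p["name"]](**p["options"]) for p in problems]

  def build() -> float:
    loss = 0.0
    for i, build_fn in enumerate(build_fns):
      loss_p = build_fn()
      if weights:
        loss_p *= weights[i]
      loss += loss_p
    return loss

  return build


specs = [
    {"name": "constant_problem", "options": {"value": 1.0}},
    {"name": "square_problem", "options": {"x": 3.0}},
]
print(ensemble_sketch(specs)())              # → 10.0  (1.0 + 9.0)
print(ensemble_sketch(specs, [2.0, 0.5])())  # → 6.5   (2.0 + 4.5)
```

Because both checks test weights for truthiness, passing weights=[] behaves like weights=None rather than raising, both in this sketch and in ensemble itself.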

Callers (3)

testShapeMethod · 0.85
testVariablesMethod · 0.85
testValuesMethod · 0.85

Calls

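No statically resolved outgoing calls; the problem builders are looked up by name at run time via getattr(sys.modules[__name__], ...).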

Tested by (3)

testShapeMethod · 0.68
testVariablesMethod · 0.68
testValuesMethod · 0.68