(layer_kwargs)
| 226 | |
| 227 | |
def eval_layer(layer_kwargs):
    """Rebuild a ``tl.layers`` layer instance from its serialized config.

    Parameters
    ----------
    layer_kwargs : dict
        Layer config with a ``'class'`` entry (the name of a class under
        ``tl.layers``) and an ``'args'`` dict holding the constructor keyword
        arguments plus a ``'layer_type'`` tag that selects how those
        arguments are post-processed before construction:

        * ``"normal"``     -- ``generate_func`` is applied to the arguments.
        * ``"layerlist"``  -- ``args['layers']`` is a list of nested layer
          configs, each rebuilt recursively with ``eval_layer``.
        * ``"modellayer"`` -- ``args['model']`` is a model config rebuilt
          with ``static_graph2net``.
        * ``"keraslayer"`` -- ``args['fn']`` is loaded with
          ``load_keras_model``, called once on random float32 input of shape
          ``args['keras_input_shape']``, and passed on together with its
          ``trainable_variables`` as ``fn_weights``.

        Neither ``layer_kwargs`` nor its ``'args'`` dict is mutated. Nested
        values are shared with the result, not copied.

    Returns
    -------
    The instance produced by calling ``tl.layers.<class>(**args)``.

    Raises
    ------
    RuntimeError
        If ``layer_type`` is not one of the four values above.
    KeyError
        If ``'class'``, ``'args'`` or ``'layer_type'`` is missing.
    AttributeError
        If ``'class'`` does not name an attribute of ``tl.layers``.
    """
    layer_class = layer_kwargs['class']
    # Work on a shallow copy so the caller's config survives the pops and
    # reassignments below and can be evaluated again.
    args = dict(layer_kwargs['args'])
    layer_type = args.pop('layer_type')

    if layer_type == "normal":
        generate_func(args)
    elif layer_type == "layerlist":
        args['layers'] = [
            eval_layer(layer_graph) for layer_graph in args['layers']
        ]
    elif layer_type == "modellayer":
        args['model'] = static_graph2net(args['model'])
    elif layer_type == "keraslayer":
        keras_model = load_keras_model(args['fn'])
        input_shape = args.pop('keras_input_shape')
        # One forward pass on dummy data, presumably so the Keras model
        # creates its variables before trainable_variables is read below.
        _ = keras_model(np.random.random(input_shape).astype(np.float32))
        args['fn'] = keras_model
        args['fn_weights'] = keras_model.trainable_variables
    else:
        raise RuntimeError("Unknown layer type. Got {!r}.".format(layer_type))

    # Resolve the class by attribute lookup rather than ``eval`` so that a
    # crafted 'class' string in a loaded config cannot execute arbitrary
    # code. Dotted names still resolve, as they did with eval.
    layer_cls = tl.layers
    for attr in layer_class.split('.'):
        layer_cls = getattr(layer_cls, attr)
    return layer_cls(**args)
| 255 | |
| 256 | |
| 257 | def static_graph2net(model_config): |
no test coverage detected
searching dependent graphs…