(model_config)
| 255 | |
| 256 | |
def _static_graph_lookup(layer_dict, names):
    """Return the node(s) registered in ``layer_dict`` under ``names``.

    ``names`` is either a single key, which returns one node, or a list of
    keys, which returns a list of nodes in the same order. Raises
    ``KeyError`` if a name has not been registered yet.
    """
    if isinstance(names, list):
        return [layer_dict[name] for name in names]
    return layer_dict[names]


def static_graph2net(model_config):
    """Rebuild a static-graph ``Model`` from a serialized model config.

    Parameters
    ----------
    model_config : dict
        Must contain the keys ``"name"``, ``"inputs"``, ``"outputs"`` and
        ``"model_architecture"``.

        - ``"inputs"`` / ``"outputs"``: a single node name, or a list of
          node names, looked up among the nodes built from the architecture.
        - ``"model_architecture"``: a list of layer entries. Each entry is a
          dict with at least ``"class"`` and ``"prev_layer"``; the entry
          (minus ``"prev_layer"``) is passed to ``eval_layer`` to construct
          the layer.
        - ``"prev_layer"``: either ``None`` or a list. ``None`` means the
          built layer itself is registered. In a list, each item is a
          single name (the layer is called on that node) or a list of names
          (the layer is called on the list of nodes).

        ``model_config`` is not modified; each layer entry is deep-copied
        before it is consumed.

    Returns
    -------
    tensorlayer.models.Model
        Model built with the resolved inputs/outputs and ``model_config["name"]``.

    Raises
    ------
    KeyError
        If a required key is missing, or a referenced node name was not
        produced by an earlier layer entry.
    """
    import copy

    layer_dict = {}
    model_name = model_config["name"]
    inputs_tensors = model_config["inputs"]
    outputs_tensors = model_config["outputs"]

    for layer_kwargs in model_config["model_architecture"]:
        # Work on a private copy: popping "prev_layer" below (and whatever
        # eval_layer does with the dict) must not mutate the caller's config,
        # otherwise rebuilding from the same config would fail.
        layer_kwargs = copy.deepcopy(layer_kwargs)
        # Read the class name before eval_layer sees the dict.
        layer_class = layer_kwargs["class"]
        prev_layers = layer_kwargs.pop("prev_layer")
        net = eval_layer(layer_kwargs)
        if layer_class in tl.layers.inputs.__all__:
            # Input-type layers: keep the first output tensor of their first
            # node rather than the layer object itself.
            net = net._nodes[0].out_tensors[0]

        if prev_layers is None:
            layer_dict[net._info[0].name] = net
            continue

        # Each entry of prev_layers is one call of the layer, so a shared
        # layer can appear with several (lists of) inputs.
        for prev_layer in prev_layers:
            output = net(_static_graph_lookup(layer_dict, prev_layer))
            layer_dict[output._info[0].name] = output

    model_inputs = _static_graph_lookup(layer_dict, inputs_tensors)
    model_outputs = _static_graph_lookup(layer_dict, outputs_tensors)

    # Imported lazily, as in the original code.
    from tensorlayer.models import Model

    M = Model(inputs=model_inputs, outputs=model_outputs, name=model_name)
    logging.info("[*] Load graph finished")
    return M
| 297 | |
| 298 | |
| 299 | def load_hdf5_graph(filepath='model.hdf5', load_weights=False): |
no test coverage detected
searching dependent graphs…