(model, layer_type)
| 155 | |
| 156 | |
def restore_model(model, layer_type):
    """Load pre-trained VGG weights from disk into ``model``.

    Ensures the weight file for ``layer_type`` is present under the ``models``
    directory (via ``maybe_download_and_extract``), collects its arrays in
    sorted-key order, and passes them to ``assign_weights``. Collection stops
    once enough arrays have been gathered to cover ``model.all_weights``, and
    any surplus is dropped, so only the leading weights are assigned.

    Args:
        model: The model to populate. Only ``len(model.all_weights)`` is read
            here, to bound how many arrays are loaded.
        layer_type: Key into ``model_saved_name`` / ``model_urls``; must be
            ``'vgg16'`` or ``'vgg19'``.

    Raises:
        KeyError: If ``layer_type`` is not a key of ``model_saved_name`` /
            ``model_urls``.
        ValueError: If ``layer_type`` is a known key but not ``'vgg16'`` or
            ``'vgg19'`` (previously this silently assigned no weights).
    """
    logging.info("Restore pre-trained weights")
    # download weights
    maybe_download_and_extract(model_saved_name[layer_type], 'models', model_urls[layer_type])
    path = os.path.join('models', model_saved_name[layer_type])
    num_needed = len(model.all_weights)
    weights = []
    if layer_type == 'vgg16':
        # NOTE(review): allow_pickle=True lets np.load unpickle objects from a
        # downloaded file; a plain .npz of arrays should not need it -- confirm
        # the file's contents before removing the flag.
        # `with` closes the NpzFile's underlying zip handle once we are done.
        with np.load(path, allow_pickle=True) as npz:
            for key, value in sorted(npz.items()):
                logging.info(" Loading weights %s in %s", value.shape, key)
                weights.append(value)
                # `>=` rather than `==` so a mismatch can never skip the stop.
                if len(weights) >= num_needed:
                    break
    elif layer_type == 'vgg19':
        # The file holds a pickled dict (hence allow_pickle and .item()); each
        # value is indexed at [0] and [1] and extended into the flat list --
        # presumably a [W, b] pair per layer (TODO confirm).
        # NOTE(review): unpickling a downloaded file trusts its source.
        params = np.load(path, allow_pickle=True, encoding='latin1').item()
        for key, value in sorted(params.items()):
            logging.info(" Loading %s in %s", value[0].shape, key)
            logging.info(" Loading %s in %s", value[1].shape, key)
            weights.extend(value)
            # extend() adds more than one array, so `==` could be stepped over.
            if len(weights) >= num_needed:
                break
    else:
        raise ValueError("Unsupported layer_type for weight restore: %r" % (layer_type,))
    # assign weight values; trim any overshoot so the list never exceeds the model's weights
    assign_weights(weights[:num_needed], model)
| 182 | |
| 183 | |
| 184 | def VGG_static(layer_type, batch_norm=False, end_with='outputs', name=None): |
no test coverage detected
searching dependent graphs…