(network, path='models')
| 177 | |
| 178 | |
def restore_params(network, path='models'):
    """Load the pre-trained ResNet-50 weights from an HDF5 file into ``network``.

    The weight file is fetched into ``path`` via ``maybe_download_and_extract``
    and then opened read-only with ``h5py``. Every layer in
    ``network.all_layers`` is visited. For each layer that has weights, the
    datasets stored under the HDF5 group named after the layer are read and
    passed to ``assign_weights``.

    Parameters
    ----------
    network :
        Model whose layers are filled in. Each entry of ``network.all_layers``
        is read for ``name`` and ``all_weights``.
    path : str
        Directory that holds (or will receive) the ``.h5`` weight file.

    Raises
    ------
    ImportError
        If ``h5py`` cannot be imported; the original error is chained as the
        cause.
    """
    filename = 'resnet50_weights_tf_dim_ordering_tf_kernels.h5'

    logging.info("Restore pre-trained parameters")
    maybe_download_and_extract(
        filename,
        path,
        'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/',
    )

    try:
        import h5py
    except Exception as err:
        # Chain the underlying failure so a broken install is diagnosable.
        raise ImportError('h5py not imported') from err

    # The context manager closes the file even if a layer lookup or
    # assign_weights raises part-way through the loop.
    with h5py.File(os.path.join(path, filename), 'r') as weight_file:
        for layer in network.all_layers:
            # Layers that own no weights have nothing to restore.
            if len(layer.all_weights) == 0:
                continue
            group = weight_file[layer.name]
            # NOTE(review): values are passed in h5py's iteration order for the
            # group. That order must match the order assign_weights expects
            # for this layer; confirm for each layer type.
            params = [group[name][:] for name in group]
            assign_weights(params, layer)
no test coverage detected
searching dependent graphs…