(f, force_fp16=False)
| 20 | #---------------------------------------------------------------------------- |
| 21 | |
def load_network_pkl(f, force_fp16=False):
    """Load a network snapshot pickle and normalize it into a dict.

    A legacy TensorFlow snapshot (a 3-tuple of ``_TFNetworkStub``) is converted
    into PyTorch modules first. Optional entries that are absent are filled
    with ``None``, and the result is validated before it is returned.

    Args:
        f: File-like object passed to ``_LegacyUnpickler``.
        force_fp16: If true, rebuild 'G', 'D' and 'G_ema' with
            ``num_fp16_res = 4`` and ``conv_clamp = 256``. Each network's
            parameters and buffers are copied into its rebuilt instance.

    Returns:
        The unpickled mapping. It is guaranteed to contain 'G', 'D', 'G_ema'
        (``torch.nn.Module``), 'training_set_kwargs' (dict or None) and
        'augment_pipe' (``torch.nn.Module`` or None).

    Raises:
        AssertionError: If the loaded contents fail the type checks. These are
            plain ``assert`` statements, so they are skipped under ``python -O``.
    """
    # NOTE(review): unpickling can execute arbitrary code -- only call this on
    # snapshot files from a trusted source.
    data = _LegacyUnpickler(f).load()

    # A 3-tuple made entirely of TF network stubs is a legacy TensorFlow pickle.
    is_legacy_tf = (
        isinstance(data, tuple)
        and len(data) == 3
        and all(isinstance(net, _TFNetworkStub) for net in data)
    )
    if is_legacy_tf:
        tf_gen, tf_disc, tf_gen_ema = data
        # Keyword arguments are evaluated left to right, so conversion order is G, D, G_ema.
        data = dict(
            G=convert_tf_generator(tf_gen),
            D=convert_tf_discriminator(tf_disc),
            G_ema=convert_tf_generator(tf_gen_ema),
        )

    # Older snapshots may omit these entries; default them to None.
    for optional_key in ('training_set_kwargs', 'augment_pipe'):
        if optional_key not in data:
            data[optional_key] = None

    # Sanity-check the loaded contents.
    for net_key in ('G', 'D', 'G_ema'):
        assert isinstance(data[net_key], torch.nn.Module)
    assert isinstance(data['training_set_kwargs'], (dict, type(None)))
    assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))

    if not force_fp16:
        return data

    for net_key in ('G', 'D', 'G_ema'):
        old_net = data[net_key]
        new_kwargs = copy.deepcopy(old_net.init_kwargs)
        # Patch the nested 'synthesis_kwargs' mapping when present, else the top level.
        # NOTE(review): attribute assignment assumes an attribute-style mapping
        # (e.g. EasyDict) for init_kwargs -- confirm.
        target = new_kwargs.get('synthesis_kwargs', new_kwargs)
        target.num_fp16_res = 4
        target.conv_clamp = 256
        # Only rebuild when the patch actually changed the constructor arguments.
        if new_kwargs != old_net.init_kwargs:
            rebuilt = type(old_net)(**new_kwargs).eval().requires_grad_(False)
            misc.copy_params_and_buffers(old_net, rebuilt, require_all=True)
            data[net_key] = rebuilt
    return data
| 59 | |
| 60 | #---------------------------------------------------------------------------- |
| 61 |
no test coverage detected