MCPcopy
hub / github.com/XingangPan/DragGAN / load_network_pkl

Function load_network_pkl

legacy.py:22–58  ·  view source on GitHub ↗
(f, force_fp16=False)

Source from the content-addressed store, hash-verified

20#----------------------------------------------------------------------------
21
def load_network_pkl(f, force_fp16=False):
    """Load a network pickle, converting legacy TF pickles and normalizing its contents.

    Args:
        f: File-like object holding the pickle; it is read via ``_LegacyUnpickler``.
            NOTE(review): unpickling can execute arbitrary code — only load
            pickles from trusted sources.
        force_fp16: If True, rebuild ``G``, ``D`` and ``G_ema`` with
            ``num_fp16_res=4`` and ``conv_clamp=256``. The values are written into
            the network's ``synthesis_kwargs`` when its init kwargs have that key,
            otherwise into the top-level init kwargs. Parameters and buffers are
            copied from the loaded network. A network is only rebuilt if this
            actually changes its init kwargs.

    Returns:
        dict with (at least) the keys ``G``, ``D``, ``G_ema`` (each a
        ``torch.nn.Module``), ``training_set_kwargs`` (dict or None) and
        ``augment_pipe`` (``torch.nn.Module`` or None).

    Raises:
        AssertionError: if the unpickled payload is not a dict (after legacy
            conversion) or its entries have unexpected types. Raised explicitly
            so the checks still run under ``python -O``.
    """
    data = _LegacyUnpickler(f).load()

    # Legacy TensorFlow pickle => convert.
    if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
        tf_G, tf_D, tf_Gs = data
        G = convert_tf_generator(tf_G)
        D = convert_tf_discriminator(tf_D)
        G_ema = convert_tf_generator(tf_Gs)
        data = dict(G=G, D=D, G_ema=G_ema)

    # Anything that is not a converted TF pickle must already be a dict; fail
    # clearly here instead of with an incidental TypeError/KeyError below.
    if not isinstance(data, dict):
        raise AssertionError(f'network pickle: unsupported payload type {type(data).__name__}')

    # Add missing fields.
    data.setdefault('training_set_kwargs', None)
    data.setdefault('augment_pipe', None)

    # Validate contents. Explicit raises (not `assert`) so the checks are not
    # stripped when Python runs with -O.
    for key in ('G', 'D', 'G_ema'):
        if not isinstance(data.get(key), torch.nn.Module):
            raise AssertionError(f'network pickle: {key!r} is missing or not a torch.nn.Module')
    if not isinstance(data['training_set_kwargs'], (dict, type(None))):
        raise AssertionError("network pickle: 'training_set_kwargs' must be a dict or None")
    if not isinstance(data['augment_pipe'], (torch.nn.Module, type(None))):
        raise AssertionError("network pickle: 'augment_pipe' must be a torch.nn.Module or None")

    # Force FP16.
    if force_fp16:
        for key in ('G', 'D', 'G_ema'):
            old = data[key]
            kwargs = copy.deepcopy(old.init_kwargs)
            # Networks with nested 'synthesis_kwargs' (presumably the generators)
            # take the FP16 settings there; the others take them at top level.
            fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs)
            # Item assignment works for plain dicts as well as attribute-style
            # dict subclasses; attribute assignment would fail on a plain dict.
            fp16_kwargs['num_fp16_res'] = 4
            fp16_kwargs['conv_clamp'] = 256
            # Rebuild only when the settings actually changed.
            if kwargs != old.init_kwargs:
                new = type(old)(**kwargs).eval().requires_grad_(False)
                misc.copy_params_and_buffers(old, new, require_all=True)
                data[key] = new
    return data
59
60#----------------------------------------------------------------------------
61

Callers 1

convert_network_pickle (Function) · 0.70

Calls 4

convert_tf_generator (Function) · 0.85
convert_tf_discriminator (Function) · 0.85
load (Method) · 0.80
_LegacyUnpickler (Class) · 0.70

Tested by

no test coverage detected