hub / github.com/tensorpack/tensorpack / rename_get_variable

Function rename_get_variable

tensorpack/models/tflayer.py:73–89

Args:
    mapping (dict): an old -> new mapping for variable basenames, e.g. {'kernel': 'W'}.

Returns:
    A context where the variables are renamed.

(mapping)

Source from the content-addressed store, hash-verified

def rename_get_variable(mapping):
    """
    Args:
        mapping(dict): an old -> new mapping for variable basename. e.g. {'kernel': 'W'}

    Returns:
        A context where the variables are renamed.
    """
    def custom_getter(getter, name, *args, **kwargs):
        splits = name.split('/')
        basename = splits[-1]
        if basename in mapping:
            basename = mapping[basename]
            splits[-1] = basename
            name = '/'.join(splits)
        return getter(name, *args, **kwargs)
    return custom_getter_scope(custom_getter)


def rename_tflayer_get_variable():
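The renaming step inside custom_getter can be tried on its own, without TensorFlow. The following is a minimal sketch, not tensorpack code: rename_basename, make_custom_getter and recording_getter are hypothetical names introduced here for illustration, and recording_getter merely stands in for the real variable getter by returning the name it was asked for.

```python
def rename_basename(name, mapping):
    """Replace the last '/'-separated component of `name` if it is a key of `mapping`."""
    splits = name.split('/')
    basename = splits[-1]
    if basename in mapping:
        splits[-1] = mapping[basename]
    return '/'.join(splits)


def make_custom_getter(mapping):
    """Build a getter wrapper with the same (getter, name, *args, **kwargs) shape as above."""
    def custom_getter(getter, name, *args, **kwargs):
        # Rewrite the requested name, then delegate to the wrapped getter.
        return getter(rename_basename(name, mapping), *args, **kwargs)
    return custom_getter


def recording_getter(name, *args, **kwargs):
    # Hypothetical stand-in for the real variable getter: returns the requested name.
    return name


getter = make_custom_getter({'kernel': 'W', 'bias': 'b'})
print(getter(recording_getter, 'conv1/kernel'))   # conv1/W
print(getter(recording_getter, 'conv1/bias'))     # conv1/b
print(getter(recording_getter, 'kernel/gamma'))   # kernel/gamma
```

Only the final path component is compared against the mapping, so a scope that happens to be called kernel is left untouched, and names whose basename is not in the mapping pass through unchanged.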

Callers 5

BatchNorm · Function · 0.85
Conv2D · Function · 0.85
Conv2DTranspose · Function · 0.85
FullyConnected · Function · 0.85

Calls 1

custom_getter_scope · Function · 0.85

Tested by

no test coverage detected