MCPcopy Index your code
hub / github.com/tensorpack/tensorpack / wrapped_func

Function wrapped_func

tensorpack/models/registry.py:128–181  ·  view source on GitHub ↗
(*args, **kwargs)

Source from the content-addressed store, hash-verified

126 def wrapper(func):
@wraps(func)
def wrapped_func(*args, **kwargs):
    """Call the registered layer ``func`` with scope handling, argscope
    defaults and optional input/output shape logging.

    Positional calling conventions (``use_scope`` and ``log_shape`` come from
    the enclosing registration call):

    * ``use_scope`` truthy: ``layer(name, inputs, *rest, **kwargs)``; ``func``
      runs inside ``tfv1.variable_scope(name)``.
    * ``use_scope`` falsy: ``layer(inputs, *rest, **kwargs)``. A leading
      string is still accepted and used as a scope name; when
      ``use_scope is False`` that usage logs a warning.

    Keyword arguments are merged over a copy of the argscope entry for
    ``func.__name__``; explicit kwargs win.

    Returns:
        Whatever ``func`` returns.

    Raises:
        AssertionError: if no positional argument is given or the first one is
            ``None``; if ``use_scope`` is truthy and the first argument is not
            a string; or if ``log_shape`` is set while ``use_scope`` is falsy.
        ValueError: if no inputs follow the scope name, or the inputs are
            neither a ``tf.Tensor``/``tf.Variable`` nor a non-empty list/tuple
            whose first element is one.
    """
    # Check the length first: a call with no positional arguments would
    # otherwise fail with a bare IndexError.
    assert len(args) > 0 and args[0] is not None, args
    if use_scope:
        name = args[0]
        # Validate the name before looking for inputs, so that `Layer(x)`
        # (name forgotten, no other positional args) reaches this helpful
        # message instead of an IndexError on args[1].
        assert isinstance(name, six.string_types), (
            "First argument for \"{}\" should be a string. ".format(func.__name__) +
            "Did you forget to specify the name of the layer?")
        args = args[1:]  # actual positional args used to call func
    else:
        # log_shape is only supported together with use_scope.
        assert not log_shape
        if isinstance(args[0], six.string_types):
            if use_scope is False:
                logger.warn(
                    "Please call layer {} without the first scope name argument, "
                    "or register the layer with use_scope=None to allow calling it "
                    "with scope names.".format(func.__name__))
            name = args[0]
            args = args[1:]  # actual positional args used to call func
        else:
            name = None

    # After stripping the optional scope name, args[0] must be the layer input.
    if not args:
        raise ValueError(
            "Invalid inputs to layer: no inputs given after scope name " + str(name))
    inputs = args[0]
    tensor_types = (tf.Tensor, tf.Variable)
    # len() check keeps an empty list/tuple on the ValueError path instead of
    # raising IndexError on inputs[0].
    if not (isinstance(inputs, tensor_types) or
            (isinstance(inputs, (list, tuple)) and len(inputs) > 0 and
             isinstance(inputs[0], tensor_types))):
        raise ValueError("Invalid inputs to layer: " + str(inputs))

    # use kwargs from current argument scope
    actual_args = copy.copy(get_arg_scope()[func.__name__])
    # explicit kwargs overwrite argscope
    actual_args.update(kwargs)
    # NOTE: explicit *positional* args do not override argscope entries. If the
    # argscope sets a parameter that is also passed positionally, func receives
    # it twice and Python raises TypeError ("got multiple values"). Dropping such
    # keys (e.g. via inspect.signature(func).bind_partial(*args), PY3 only) is
    # not done here.

    if name is not None:  # use scope
        with tfv1.variable_scope(name) as scope:
            # Strip 'towerN/' prefixes so a layer replicated across towers is
            # shape-logged only once; this name only affects logging.
            scope_name = re.sub('tower[0-9]+/', '', scope.name)
            do_log_shape = log_shape and scope_name not in _LAYER_LOGGED
            if do_log_shape:
                _SHAPE_LOGGER.push_inputs(scope.name, get_shape_str(inputs))

            # run the actual function
            outputs = func(*args, **actual_args)

            if do_log_shape:
                _SHAPE_LOGGER.push_outputs(scope.name, get_shape_str(outputs))
                _LAYER_LOGGED.add(scope_name)
    else:
        # run the actual function
        outputs = func(*args, **actual_args)
    return outputs
182
183 wrapped_func.use_scope = use_scope
184 wrapped_func.__argscope_enabled__ = True

Callers

nothing calls this directly

Calls (7)

get_arg_scope — Function · 0.85
get_shape_str — Function · 0.85
format — Method · 0.80
update — Method · 0.80
push_inputs — Method · 0.80
push_outputs — Method · 0.80
add — Method · 0.45

Tested by

no test coverage detected