(*args, **kwargs)
| 126 | def wrapper(func): |
@wraps(func)
def wrapped_func(*args, **kwargs):
    """Call ``func`` with argscope defaults, optionally inside a variable scope.

    The leading positional arguments are interpreted according to the
    ``use_scope`` value captured from the enclosing registration call:

    * truthy: ``args[0]`` must be the scope name (a string) and the next
      positional argument is the layer input.
    * falsy: ``args[0]`` may be a scope name followed by the input, or the
      input itself. When ``use_scope is False`` and a name is given, a
      warning is logged but the name is still used.

    The scope name, if any, is removed from ``args`` before ``func`` is
    called, so ``func`` receives the input as its first positional argument.
    Keyword arguments are taken from the current argscope entry for
    ``func.__name__`` and then overridden by the explicit ``kwargs``.

    Args:
        *args: ``[name,] inputs, ...`` as described above.
        **kwargs: explicit keyword arguments for ``func``.

    Returns:
        Whatever ``func`` returns.

    Raises:
        AssertionError: if no positional argument is given, the first one is
            None, a required scope name is not a string, or ``log_shape`` is
            set while ``use_scope`` is falsy.
        ValueError: if the input is missing, or is neither a
            ``tf.Tensor``/``tf.Variable`` nor a non-empty list/tuple whose
            first element is one.
    """
    # `args and ...` turns an argument-less call into a diagnostic
    # instead of an IndexError on args[0].
    assert args and args[0] is not None, args
    if use_scope:
        name = args[0]
        # Validate the name before touching the inputs, so that a call which
        # forgot the name (and passed only the tensor) gets this message
        # rather than an IndexError.
        assert isinstance(name, six.string_types), "First argument for \"{}\" should be a string. ".format(
            func.__name__) + "Did you forget to specify the name of the layer?"
        args = args[1:]  # actual positional args used to call func
    else:
        assert not log_shape
        if isinstance(args[0], six.string_types):
            if use_scope is False:
                logger.warn(
                    "Please call layer {} without the first scope name argument, "
                    "or register the layer with use_scope=None to allow calling it "
                    "with scope names.".format(func.__name__))
            name = args[0]
            args = args[1:]  # actual positional args used to call func
        else:
            name = None

    if not args:
        # Only a scope name was supplied.
        raise ValueError(
            "Layer \"{}\" was called without any inputs.".format(func.__name__))
    inputs = args[0]
    # The length check makes an empty list/tuple fall through to the
    # ValueError below instead of failing on inputs[0].
    if not (isinstance(inputs, (tf.Tensor, tf.Variable)) or
            (isinstance(inputs, (list, tuple)) and
             len(inputs) > 0 and
             isinstance(inputs[0], (tf.Tensor, tf.Variable)))):
        raise ValueError("Invalid inputs to layer: " + str(inputs))

    # use kwargs from current argument scope
    actual_args = copy.copy(get_arg_scope()[func.__name__])
    # explicit kwargs overwrite argscope
    actual_args.update(kwargs)
    # if six.PY3:
    #     # explicit positional args also override argscope. only work in PY3
    #     posargmap = inspect.signature(func).bind_partial(*args).arguments
    #     for k in six.iterkeys(posargmap):
    #         if k in actual_args:
    #             del actual_args[k]

    if name is not None:  # use scope
        with tfv1.variable_scope(name) as scope:
            # this name is only used to suppress logging, doesn't hurt to do some heuristics
            scope_name = re.sub('tower[0-9]+/', '', scope.name)
            # Log shapes only the first time this (tower-stripped) name is seen.
            do_log_shape = log_shape and scope_name not in _LAYER_LOGGED
            if do_log_shape:
                _SHAPE_LOGGER.push_inputs(scope.name, get_shape_str(inputs))

            # run the actual function
            outputs = func(*args, **actual_args)

            if do_log_shape:
                _SHAPE_LOGGER.push_outputs(scope.name, get_shape_str(outputs))
                _LAYER_LOGGED.add(scope_name)
    else:
        # run the actual function
        outputs = func(*args, **actual_args)
    return outputs
| 182 | |
| 183 | wrapped_func.use_scope = use_scope |
| 184 | wrapped_func.__argscope_enabled__ = True |
nothing calls this directly
no test coverage detected