(*args, **kwargs)
| 95 | |
@wraps(func)
def wrapped_func(*args, **kwargs):
    """Call ``func`` with arg-scope defaults merged in, optionally logging shapes.

    Keyword defaults registered for ``func.__name__`` in ``get_arg_scope()``
    are shallow-copied, then overridden by the keyword arguments passed at
    the call site, and ``func`` is invoked with the merged set.

    When ``log_shape`` (from the enclosing decorator) is true, one line is
    logged with the layer name (the ``name`` keyword if given, otherwise
    ``func.__name__``), the shape of the first positional argument, and the
    shape of the output (its first element if the output is a tuple).
    Logging is skipped inside a tower context whose ``ns_name`` contains
    "tower" unless it is the main training tower.

    Returns:
        Whatever ``func`` returns, unchanged.
    """
    # Copy so that updating with call-site kwargs does not mutate the dict
    # stored in the arg scope.
    actual_args = copy.copy(get_arg_scope()[func.__name__])
    actual_args.update(kwargs)  # explicit call-site kwargs take precedence
    out_tensor = func(*args, **actual_args)

    if log_shape:
        ctx = get_current_tower_context()
        name = kwargs.get('name', func.__name__)
        # NOTE(review): ctx is presumably None when no tower context is
        # active; treat that as "not inside a tower" and log, instead of
        # raising AttributeError on ctx.ns_name.
        in_main_or_no_tower = (
            ctx is None
            or 'tower' not in ctx.ns_name.lower()
            or ctx.is_main_training_tower
        )
        if in_main_or_no_tower:
            # we assume the first parameter is the most interesting;
            # read it only here so calls that pass the input by keyword
            # still work when shape logging is disabled.
            in_tensor = args[0]
            if isinstance(out_tensor, tuple):
                out_tensor_descr = out_tensor[0]
            else:
                out_tensor_descr = out_tensor
            logger.info("{:<12}: {} --> {}".format(
                "'{}'".format(name),
                get_shape_str(in_tensor),
                get_shape_str(out_tensor_descr)))

    return out_tensor
| 118 | wrapped_func.__argscope_enabled__ = True |
| 119 | return wrapped_func |
| 120 |
nothing calls this directly
no test coverage detected