Args: mapping (dict): an old -> new mapping for variable basenames, e.g. {'kernel': 'W'}. Returns: A context in which the variables are renamed.
(mapping)
| 71 | |
| 72 | |
def rename_get_variable(mapping):
    """
    Build a variable-getter context that renames variable basenames.

    Only the last '/'-separated component of a requested variable name
    (its basename) is looked up in ``mapping``; any leading scope prefix
    is kept as-is.

    Args:
        mapping(dict): an old -> new mapping for variable basename. e.g. {'kernel': 'W'}

    Returns:
        A context where the variables are renamed.
    """
    def _renaming_getter(getter, name, *args, **kwargs):
        # rpartition yields ('', '', name) when there is no '/', so
        # rejoining the pieces reproduces the original name exactly.
        scope, sep, base = name.rpartition('/')
        if base in mapping:
            name = scope + sep + mapping[base]
        return getter(name, *args, **kwargs)

    return custom_getter_scope(_renaming_getter)
| 90 | |
| 91 | |
| 92 | def rename_tflayer_get_variable(): |
no test coverage detected