(module, prefix="")
| 70 | |
| 71 | |
| 72 | def get_compressed_list(module, prefix=""): |
| 73 | compressed_list = [] |
| 74 | for attr_str in dir(module): |
| 75 | target_attr = getattr(module, attr_str) |
| 76 | if type(target_attr) == torch.nn.Linear: |
| 77 | full_name = ( |
| 78 | f"{prefix}.{attr_str}.weight" if prefix else f"{attr_str}.weight" |
| 79 | ) |
| 80 | compressed_list.append(full_name) |
| 81 | for name, child in module.named_children(): |
| 82 | child_prefix = f"{prefix}.{name}" if prefix else name |
| 83 | for each in get_compressed_list(child, child_prefix): |
| 84 | compressed_list.append(each) |
| 85 | return compressed_list |
| 86 | |
| 87 | |
| 88 | def apply_compressed_weight(module, compressed_state_dict, target_device, prefix=""): |
no outgoing calls
no test coverage detected
searching dependent graphs…