(module, layers=[nn.Conv2d, nn.Linear], name='')
| 9 | |
| 10 | |
def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
    """Collect the sub-modules of *module* whose exact type is listed in *layers*.

    The module tree is walked depth-first through ``named_children()``.
    Matching is done on ``type(module)`` itself, so subclasses of the listed
    types are not collected. A module that matches is not descended into.

    Args:
        module: Root of the module tree to search.
        layers: Container of module classes to look for.
        name: Dotted path of *module* relative to the search root; the empty
            string denotes the root itself.

    Returns:
        A dict mapping dotted path names to the matching modules. If *module*
        itself matches, the dict holds the single entry ``{name: module}``.
    """
    # Leaf case: the module itself is one of the requested layer types.
    if type(module) in layers:
        return {name: module}

    found = {}
    for child_name, child in module.named_children():
        # The root has an empty name, so its children are not given a
        # leading dot.
        if name != '':
            child_path = name + '.' + child_name
        else:
            child_path = child_name
        found.update(find_layers(child, layers=layers, name=child_path))
    return found
| 20 | |
| 21 | |
| 22 | # Code based on https://github.com/fpgaminer/GPTQ-triton |
no test coverage detected