(config)
| 16 | |
| 17 | |
def build_transform(config):
    """Instantiate a transform module described by a config dict.

    Args:
        config (dict): Must contain a ``"name"`` entry naming one of the
            supported transforms: ``"TPS"``, ``"STN_ON"``, ``"GA_SPIN"``,
            ``"TSRN"`` or ``"TBSRN"``. Every other entry is forwarded as a
            keyword argument to the selected class's constructor. The
            caller's dict is not modified.

    Returns:
        An instance of the selected transform class.

    Raises:
        KeyError: If ``config`` has no ``"name"`` entry.
        ValueError: If ``config["name"]`` is not a supported transform.
    """
    import importlib

    # Supported name -> (sibling module, class name). Insertion order matches
    # the list shown in the error message. Only the requested module is
    # imported, so a broken import in one transform cannot affect the others.
    support_dict = {
        "TPS": ("tps", "TPS"),
        "STN_ON": ("stn", "STN_ON"),
        "GA_SPIN": ("gaspin_transformer", "GA_SPIN_Transformer"),
        "TSRN": ("tsrn", "TSRN"),
        "TBSRN": ("tbsrn", "TBSRN"),
    }

    # Shallow copy so popping "name" does not mutate the caller's dict;
    # otherwise building twice from the same config fails with KeyError.
    config = dict(config)
    module_name = config.pop("name")

    # Explicit raise rather than `assert`: asserts vanish under `python -O`.
    if module_name not in support_dict:
        raise ValueError(
            "transform only support {}".format(list(support_dict))
        )

    module_file, class_name = support_dict[module_name]
    # Resolved relative to this package, like `from .tps import TPS`.
    module = importlib.import_module("." + module_file, package=__package__)
    transform = getattr(module, class_name)(**config)
    return transform