`convert_module_to_f16(l)`: Convert primitive modules to float16.
| 17 | |
| 18 | |
def convert_module_to_f16(l):
    """
    Convert primitive modules to float16.

    If ``l`` is an instance of one of the types in ``MIX_PRECISION_MODULES``,
    every tensor yielded by ``l.parameters()`` is replaced in place by its
    ``.half()`` copy (the ``.data`` attribute is rebound). Any other object is
    left untouched. Always returns ``None``.
    """
    # Guard clause: only the designated mixed-precision module types are cast.
    if not isinstance(l, MIX_PRECISION_MODULES):
        return
    for param in l.parameters():
        # Rebind .data so the Parameter object itself (and any references to
        # it) is preserved while its storage becomes float16.
        param.data = param.data.half()
| 26 | |
| 27 | |
| 28 | def convert_module_to_f32(l): |