| 57 | return conversion_helper(val, float_conversion) |
| 58 | |
class FP16_Module(nn.Module):
    """Wrap a module so it runs in half precision behind an fp32 interface.

    The wrapped module is converted with ``module.half()`` and registered
    under the child name ``module``. ``forward`` converts the positional
    inputs with ``fp32_to_fp16`` before calling the wrapped module and
    converts its output with ``fp16_to_fp32`` afterwards (both helpers are
    defined elsewhere in this file).
    """

    def __init__(self, module):
        super(FP16_Module, self).__init__()
        # nn.Module.half() casts floating-point parameters/buffers in place
        # and returns the same module object.
        self.add_module('module', module.half())

    def forward(self, *inputs, **kwargs):
        """Run the wrapped module on fp16-converted positional inputs.

        NOTE: only ``inputs`` go through ``fp32_to_fp16``; keyword arguments
        are forwarded to the wrapped module unchanged.
        """
        return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs))

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return the wrapped module's state dict.

        Delegating straight to ``self.module`` (instead of ``nn.Module``'s
        own implementation) means the keys carry no extra ``module.`` prefix,
        so they match the keys of the bare wrapped module.

        Arguments are passed by keyword because recent PyTorch versions
        deprecate positional arguments to ``nn.Module.state_dict``.
        """
        return self.module.state_dict(destination=destination,
                                      prefix=prefix,
                                      keep_vars=keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        """Load ``state_dict`` into the wrapped module.

        Args:
            state_dict: mapping whose keys match the wrapped module's own
                ``state_dict`` keys (no ``module.`` prefix).
            strict: forwarded to ``nn.Module.load_state_dict``.

        Returns:
            Whatever the wrapped module's ``load_state_dict`` returns (in
            current PyTorch, a named tuple of ``missing_keys`` and
            ``unexpected_keys``). It was previously discarded.
        """
        return self.module.load_state_dict(state_dict, strict=strict)
| 73 | # TODO: Update overflow check + downscale to use Carl's fused kernel. |
| 74 | class FP16_Optimizer(object): |