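Internal: rescale encoder model. Args: model: Model instance exposing `encoder.model.encoders` and `encoder.model.encoders0`. input_data: sequence of positional inputs, passed to the model as `model(*input_data)`.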
(model, input_data)
| 218 | |
| 219 | |
def _rescale_encoder_model(model, input_data):
    """Internal: rescale the encoder by an integer factor derived from its inputs.

    Runs one forward pass with temporary hooks on every encoder layer to record
    the largest absolute input value (``absmax``), computes
    ``2 * absmax // 65536`` (clamped to at least 1), then:

    * registers a forward pre-hook on ``model.encoder.model.encoders0`` with
      that factor,
    * registers a forward hook on every encoder sub-module whose name ends
      with ``"self_attn"``,
    * divides the parameters of every sub-module whose name ends with
      ``"feed_forward.w_2"`` by the factor.

    NOTE(review): the ``fp16_scale`` name and the 65536 constant suggest the
    goal is keeping activations inside the fp16 range -- confirm against the
    ``_rescale_input_hook`` / ``_rescale_output_hook`` helpers.

    Args:
        model: Model instance exposing ``encoder.model.encoders`` (iterable of
            modules) and ``encoder.model.encoders0``. ``model.cuda()`` is
            called on it and its encoder is modified in place.
        input_data: Sequence of positional inputs; the calibration pass is
            invoked as ``model(*input_data)``.
    """
    # Calculate absmax
    absmax = torch.tensor(0).cuda()

    def stat_input_hook(m, x, y):
        """Forward hook that folds the module input's max |value| into absmax.

        Args:
            m: Module the hook is attached to (unused).
            x: Module input; when it is a tuple only the first element is used.
            y: Module output (unused).
        """
        val = x[0] if isinstance(x, tuple) else x
        # In-place update so the enclosing function sees the running maximum.
        absmax.copy_(torch.max(absmax, val.detach().abs().max()))

    encoders = model.encoder.model.encoders
    hooks = [m.register_forward_hook(stat_input_hook) for m in encoders]
    try:
        model = model.cuda()
        model(*input_data)
    finally:
        # Always detach the statistics hooks, even if the forward pass raises,
        # so they do not stay attached to the caller's model.
        for h in hooks:
            h.remove()

    # Rescale encoder modules
    # Clamp to 1: for absmax < 32768 the floor division yields 0, and dividing
    # the w_2 weights by 0 below would turn them into inf/nan.
    fp16_scale = max(1, int(2 * absmax // 65536))
    print(f"rescale encoder modules with factor={fp16_scale}\n\n")
    model.encoder.model.encoders0.register_forward_pre_hook(
        functools.partial(_rescale_input_hook, scale=fp16_scale),
    )
    for name, m in model.encoder.model.named_modules():
        if name.endswith("self_attn"):
            m.register_forward_hook(functools.partial(_rescale_output_hook, scale=fp16_scale))
        if name.endswith("feed_forward.w_2"):
            # Scale every entry of the module's state dict down by the factor.
            state_dict = {k: v / fp16_scale for k, v in m.state_dict().items()}
            m.load_state_dict(state_dict)
| 260 | |
| 261 | |
| 262 | def _bladedisc_opt_for_encdec(model, path, enable_fp16): |
no test coverage detected
searching dependent graphs…