hub / github.com/modelscope/FunASR / _rescale_encoder_model

Function _rescale_encoder_model

funasr/utils/export_utils.py:220–259

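Internal helper that rescales the encoder model in place. It runs one forward pass on CUDA to record the largest absolute value fed to the encoder layers, derives an integer factor `fp16_scale` from it, registers rescaling hooks on `encoders0` and on every `self_attn` module, and divides the parameters of every `feed_forward.w_2` module by that factor. It returns nothing.

Args:
model: Model instance. It must expose `encoder.model.encoders` and `encoder.model.encoders0`, so a model name is not accepted.
input_data: Positional inputs for one forward pass, unpacked as `model(*input_data)`.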

(model, input_data)

Source

def _rescale_encoder_model(model, input_data):
    """Internal: rescale encoder model.

    Args:
        model: Model instance exposing ``encoder.model.encoders`` and
            ``encoder.model.encoders0``.
        input_data: Positional inputs, unpacked as ``model(*input_data)``.
    """
    # Calculate absmax
    absmax = torch.tensor(0).cuda()

    def stat_input_hook(m, x, y):
        """Track the largest absolute input value seen by the hooked modules.

        Args:
            m: Module the hook is registered on (unused).
            x: Inputs to the module's forward call; the first one is inspected.
            y: Output of the forward call (unused).
        """
        val = x[0] if isinstance(x, tuple) else x
        absmax.copy_(torch.max(absmax, val.detach().abs().max()))

    # Observe the input of every encoder layer during one forward pass
    encoders = model.encoder.model.encoders
    hooks = [m.register_forward_hook(stat_input_hook) for m in encoders]
    model = model.cuda()
    model(*input_data)
    for h in hooks:
        h.remove()

    # Rescale encoder modules
    fp16_scale = int(2 * absmax // 65536)
    print(f"rescale encoder modules with factor={fp16_scale}\n\n")
    model.encoder.model.encoders0.register_forward_pre_hook(
        functools.partial(_rescale_input_hook, scale=fp16_scale),
    )
    for name, m in model.encoder.model.named_modules():
        if name.endswith("self_attn"):
            m.register_forward_hook(functools.partial(_rescale_output_hook, scale=fp16_scale))
        if name.endswith("feed_forward.w_2"):
            # Divide every parameter of the layer (weight and bias) by the factor
            state_dict = {k: v / fp16_scale for k, v in m.state_dict().items()}
            m.load_state_dict(state_dict)
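
The factor computed in the rescale step can be examined in isolation. The following is a minimal sketch, not part of FunASR: `fp16_scale_factor` is a hypothetical helper that repeats the arithmetic `int(2 * absmax // 65536)` on plain Python numbers, and the sample magnitudes are made up for illustration. It suggests that the factor floors to 0 whenever the observed maximum is below 32768, in which case the later `v / fp16_scale` would divide the weights by zero, so a caller may want to check for that case.

```python
FP16_RANGE = 65536  # constant used in the source; the largest finite fp16 value is 65504


def fp16_scale_factor(absmax: float) -> int:
    """Hypothetical helper: same arithmetic as `int(2 * absmax // 65536)`."""
    return int(2 * absmax // FP16_RANGE)


# Illustrative magnitudes only, not measurements from any model.
for absmax in (1_000.0, 40_000.0, 200_000.0):
    factor = fp16_scale_factor(absmax)
    if factor == 0:
        print(f"absmax={absmax:>9}: factor=0, dividing by it is not meaningful")
    else:
        print(f"absmax={absmax:>9}: factor={factor}, scaled max={absmax / factor:.1f}")
```

Whether a factor of 0 can occur in practice depends on the model and on when the caller invokes the function, which this page does not show.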

Callers 2

_onnx_opt_for_encdec · Function · 0.85

Calls 2

state_dict · Method · 0.45
load_state_dict · Method · 0.45
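
Both calls listed here come from the `feed_forward.w_2` branch, which reads the layer's state dict, divides every entry by the factor, and loads the result back. Assuming `w_2` is an affine layer (an assumption, since its type is not shown on this page), dividing weight and bias by `s` divides the layer output by `s`, because (W/s)x + b/s = (Wx + b)/s. A minimal pure-Python sketch of that identity follows; `linear` and all numbers are hypothetical.

```python
def linear(weight, bias, x):
    """Compute y = W x + b for a weight matrix given as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


scale = 4  # stands in for fp16_scale; a power of two keeps the float math exact
weight = [[2.0, -8.0], [16.0, 4.0]]
bias = [8.0, -4.0]
x = [1.0, 0.5]

original = linear(weight, bias, x)

# Mirror `{k: v / fp16_scale ...}`: scale every parameter, weight and bias alike.
scaled_weight = [[w / scale for w in row] for row in weight]
scaled_bias = [b / scale for b in bias]
rescaled = linear(scaled_weight, scaled_bias, x)

print(original)
print(rescaled)
```

Each entry of `rescaled` equals the matching entry of `original` divided by `scale`, which is why scaling the parameters once can stand in for scaling that layer's output on every call.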

Tested by

no test coverage detected
