Repository: github.com/modelscope/FunASR

Function average_checkpoints

Defined in funasr/train_utils/average_nbest_models.py, lines 61–97

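Averages the model state_dicts of the last 'last_n' checkpoints that _get_checkpoint_paths selects from 'output_dir', then saves the result to 'output_dir/model.pt.avg{last_n}'. Entries whose dtype name starts with "torch.int" (int8, int16, int32, int64) are summed instead of averaged, and all other entries are averaged element-wise. The function returns the path of the averaged checkpoint, or None if none of the selected checkpoint files could be loaded.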
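The choice between sum and average is a string prefix test on the dtype name. The sketch below illustrates which dtypes take which branch and what each branch produces for two checkpoints. The helper takes_sum_branch is hypothetical and written only for illustration; the toy tensor values are made up.

```python
import torch


def takes_sum_branch(dtype: torch.dtype) -> bool:
    # Hypothetical helper: the same string test average_checkpoints applies
    # to the first checkpoint's tensor for each key.
    return str(dtype).startswith("torch.int")


for dt in (torch.int64, torch.int32, torch.uint8, torch.bool, torch.float32, torch.float16):
    print(f"{str(dt):14} -> {'sum' if takes_sum_branch(dt) else 'mean'}")

# Values of one int64 entry and one float entry across two checkpoints.
counters = [torch.tensor(10), torch.tensor(30)]
weights = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]
print(sum(counters))                            # tensor(40)
print(torch.mean(torch.stack(weights), dim=0))  # tensor([2., 3.])
```

One consequence of the prefix test is that unsigned and bool tensors (torch.uint8, torch.bool) fall into the averaging branch. Because torch.mean expects floating-point or complex input, a state_dict containing such entries would likely make the averaging step raise an error.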

Signature: average_checkpoints(output_dir: str, last_n: int = 5, **kwargs)


# Imports this excerpt relies on; _get_checkpoint_paths is a helper
# defined elsewhere in the same module.
import os
from collections import OrderedDict

import torch


@torch.no_grad()
def average_checkpoints(output_dir: str, last_n: int = 5, **kwargs):
    """
    Average the last 'last_n' checkpoints' model state_dicts.
    If a tensor is of type torch.int, perform sum instead of average.
    """
    checkpoint_paths = _get_checkpoint_paths(output_dir, last_n, **kwargs)
    print(f"average_checkpoints: {checkpoint_paths}")
    state_dicts = []

    # Load state_dicts from checkpoints
    for path in checkpoint_paths:
        if os.path.isfile(path):
            state_dicts.append(torch.load(path, map_location="cpu")["state_dict"])
        else:
            print(f"Checkpoint file {path} not found.")

    # Check if we have any state_dicts to average
    if len(state_dicts) < 1:
        print("No checkpoints found for averaging.")
        return

    # Average or sum weights
    avg_state_dict = OrderedDict()
    for key in state_dicts[0].keys():
        tensors = [state_dict[key].cpu() for state_dict in state_dicts]
        # Check the type of the tensor
        if str(tensors[0].dtype).startswith("torch.int"):
            # Perform sum for integer tensors
            summed_tensor = sum(tensors)
            avg_state_dict[key] = summed_tensor
        else:
            # Perform average for other types of tensors
            stacked_tensors = torch.stack(tensors)
            avg_state_dict[key] = torch.mean(stacked_tensors, dim=0)
    checkpoint_outpath = os.path.join(output_dir, f"model.pt.avg{last_n}")
    torch.save({"state_dict": avg_state_dict}, checkpoint_outpath)
    return checkpoint_outpath
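The following is a minimal, self-contained sketch of the same load, average, and save flow. It assumes checkpoints are stored as dicts with a "state_dict" key, as in the code above. Because the selection logic of _get_checkpoint_paths is not shown here, the hypothetical helper average_checkpoint_files takes an explicit list of paths. The epoch{i}.pt file names are also made up for the demo. The sketch writes three fake checkpoints of a small nn.Linear, averages them, and loads the result into a fresh model.

```python
import os
import tempfile
from collections import OrderedDict

import torch
from torch import nn


@torch.no_grad()
def average_checkpoint_files(checkpoint_paths, output_dir, last_n):
    """Hypothetical stand-in for average_checkpoints: takes explicit paths
    instead of selecting them with FunASR's _get_checkpoint_paths."""
    state_dicts = [
        torch.load(path, map_location="cpu")["state_dict"]
        for path in checkpoint_paths
        if os.path.isfile(path)
    ]
    if not state_dicts:
        return None
    avg = OrderedDict()
    for key in state_dicts[0].keys():
        tensors = [sd[key] for sd in state_dicts]
        if str(tensors[0].dtype).startswith("torch.int"):
            avg[key] = sum(tensors)  # integer entries are summed
        else:
            avg[key] = torch.mean(torch.stack(tensors), dim=0)
    out_path = os.path.join(output_dir, f"model.pt.avg{last_n}")
    torch.save({"state_dict": avg}, out_path)
    return out_path


output_dir = tempfile.mkdtemp()
model = nn.Linear(2, 2)
paths = []
for i in (1, 2, 3):
    # Fake "epoch i" checkpoint: every parameter is set to the constant i.
    with torch.no_grad():
        model.weight.fill_(float(i))
        model.bias.fill_(float(i))
    path = os.path.join(output_dir, f"epoch{i}.pt")  # hypothetical file naming
    torch.save({"state_dict": model.state_dict()}, path)
    paths.append(path)

avg_path = average_checkpoint_files(paths, output_dir, last_n=len(paths))
restored = nn.Linear(2, 2)
restored.load_state_dict(torch.load(avg_path, map_location="cpu")["state_dict"])
print(os.path.basename(avg_path))  # model.pt.avg3
print(restored.weight)             # all entries are 2.0, the mean of 1, 2 and 3
```

Loading the averaged file back with load_state_dict is a quick check that every key and shape survived averaging. In FunASR itself, the equivalent call is average_checkpoints(output_dir, last_n=5), and any extra keyword arguments are forwarded to _get_checkpoint_paths.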

Callers (2)

main · Function · 0.90
main · Function · 0.90

Calls (2)

_get_checkpoint_paths · Function · 0.85
keys · Method · 0.80

Tested by: no test coverage detected
