Repository: github.com/PRIME-RL/PRIME

Function load_fsdp_param_and_grad

Defined in training/verl/utils/fsdp_utils.py, lines 95–102
(module, device_id, load_grad=False)


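```python
import torch  # assumed module-level import; needed for torch.cuda.empty_cache()


def load_fsdp_param_and_grad(module, device_id, load_grad=False):
    # Move every parameter of `module` onto `device_id`
    # (e.g. back onto the GPU after CPU offloading).
    for _, param in module.named_parameters():
        # FSDP-managed parameters may also carry a separate local shard tensor.
        if hasattr(param, "_local_shard"):
            param._local_shard = param._local_shard.to(device_id, non_blocking=True)
        param.data = param.data.to(device_id, non_blocking=True)
        # Gradients are moved only when explicitly requested.
        if load_grad and param.grad is not None:
            param.grad = param.grad.to(device_id, non_blocking=True)
    # Release unused blocks held by PyTorch's CUDA caching allocator.
    torch.cuda.empty_cache()
```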
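To see the function's effect without a GPU cluster, the following sketch pairs a copy of it with a hypothetical CPU-offload counterpart, `offload_params_to_cpu` (an illustrative name, not taken from PRIME or verl), and runs an offload/load round trip on a small `nn.Linear`. It assumes PyTorch is installed; without CUDA the target device falls back to `cpu`, so the copies are trivial but the code path is the same. The `_local_shard` attribute is attached by hand to imitate an FSDP-managed parameter.

```python
import torch
from torch import nn


def load_fsdp_param_and_grad(module, device_id, load_grad=False):
    # Same logic as the listing above: move params (and any _local_shard)
    # to device_id, and move gradients only when load_grad is True.
    for _, param in module.named_parameters():
        if hasattr(param, "_local_shard"):
            param._local_shard = param._local_shard.to(device_id, non_blocking=True)
        param.data = param.data.to(device_id, non_blocking=True)
        if load_grad and param.grad is not None:
            param.grad = param.grad.to(device_id, non_blocking=True)
    torch.cuda.empty_cache()  # no-op if CUDA was never initialized


def offload_params_to_cpu(module, offload_grad=False):
    # Hypothetical inverse (not from PRIME/verl). Copies to the host are kept
    # blocking: a non_blocking device-to-host copy is not guaranteed to be
    # finished until the stream is synchronized.
    for _, param in module.named_parameters():
        if hasattr(param, "_local_shard"):
            param._local_shard = param._local_shard.to("cpu")
        param.data = param.data.to("cpu")
        if offload_grad and param.grad is not None:
            param.grad = param.grad.to("cpu")
    torch.cuda.empty_cache()


device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

model = nn.Linear(4, 2)
model.weight._local_shard = torch.zeros(2, 4)  # imitate an FSDP shard attribute
model(torch.randn(3, 4)).sum().backward()     # populate .grad (on CPU)

offload_params_to_cpu(model, offload_grad=True)
load_fsdp_param_and_grad(model, device, load_grad=True)
```

Parameter data is moved before its gradient in both helpers, because PyTorch rejects assigning a `.grad` whose device differs from the parameter's.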

Callers (9)

update_actor (method) · 0.90
generate_sequences (method) · 0.90
compute_ref_log_prob (method) · 0.90
save_checkpoint (method) · 0.90
compute_values (method) · 0.90
update_critic (method) · 0.90
save_checkpoint (method) · 0.90
compute_rm_score (method) · 0.90
save_checkpoint (method) · 0.90
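Judging by their names, the callers are actor, critic, reference, and reward-model entry points, which suggests the function brings offloaded weights back onto the GPU before a forward or backward pass. One way to package that load/compute/offload bracket is a context manager. This is a sketch only: `move_params` and `params_on_device` are hypothetical names, not part of PRIME or verl, and it has not been checked against each caller's actual sequence. It assumes PyTorch is installed.

```python
import contextlib

import torch
from torch import nn


def move_params(module, device, include_grad=False):
    # Hypothetical helper with the same per-parameter logic as
    # load_fsdp_param_and_grad, parameterized by the target device.
    # Copies to the CPU stay blocking so the host data is ready on return.
    non_blocking = torch.device(device).type != "cpu"
    for _, param in module.named_parameters():
        if hasattr(param, "_local_shard"):
            param._local_shard = param._local_shard.to(device, non_blocking=non_blocking)
        param.data = param.data.to(device, non_blocking=non_blocking)
        if include_grad and param.grad is not None:
            param.grad = param.grad.to(device, non_blocking=non_blocking)
    torch.cuda.empty_cache()  # no-op if CUDA was never initialized


@contextlib.contextmanager
def params_on_device(module, device, include_grad=False):
    # Load before the body runs; offload to CPU afterwards, even on error.
    move_params(module, device, include_grad)
    try:
        yield module
    finally:
        move_params(module, "cpu", include_grad)


device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(4, 2)

with params_on_device(model, device) as m:
    inside_device = m.weight.device.type
    out = m(torch.randn(3, 4, device=device))

try:
    with params_on_device(model, device):
        raise RuntimeError("simulated failure inside the block")
except RuntimeError:
    pass
```

Offloading in `finally` keeps the weights off the accelerator even when the body raises. Passing `include_grad=True` when the body runs a backward pass moves existing gradients together with their parameters.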

Calls (2)

named_parameters (method) · 0.80
to (method) · 0.80

Tested by

no test coverage detected