MCPcopy
hub / github.com/PRIME-RL/PRIME / init_model

Method init_model

training/verl/workers/fsdp_workers.py:251–310  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

249
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
    """Instantiate the model objects required by this worker's roles.

    Which attributes get populated depends on the role flags:

    * ``_is_actor`` or ``_is_rollout``: ``actor_module_fsdp``,
      ``actor_optimizer``, ``actor_lr_scheduler``, ``actor_model_config``
      (all from ``_build_model_optimizer``) and ``actor_module``.
    * ``_is_actor``: ``actor`` (a ``DataParallelPPOActor``).
    * ``_is_rollout``: ``rollout`` and ``sharding_manager``.
    * ``_is_ref``: ``ref_module_fsdp`` and ``ref_policy``.

    The method finishes with ``torch.cuda.synchronize()``.
    """
    from verl.workers.actor import DataParallelPPOActor

    model_cfg = self.config.model

    # Per the original author's note: this imports the optional
    # ``external_lib`` into the huggingface systems.
    import_external_libs(model_cfg.get('external_lib', None))

    from omegaconf import OmegaConf

    # Convert the (possibly absent) override section to a plain container.
    override_model_config = OmegaConf.to_container(model_cfg.get('override_config', OmegaConf.create()))

    # Actor and rollout both need the policy model.
    if self._is_actor or self._is_rollout:
        if self._is_actor:
            actor_cfg = self.config.actor
            optim_config, fsdp_config = actor_cfg.optim, actor_cfg.fsdp_config
        else:
            # Rollout-only worker: no optimizer config, empty FSDP config.
            optim_config, fsdp_config = None, OmegaConf.create()

        built = self._build_model_optimizer(
            model_path=model_cfg.path,
            fsdp_config=fsdp_config,
            optim_config=optim_config,
            override_model_config=override_model_config,
            enable_gradient_checkpointing=model_cfg.get('enable_gradient_checkpointing', False),
            trust_remote_code=model_cfg.get('trust_remote_code', False),
        )
        (self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler,
         self.actor_model_config) = built

        # Keep a handle on the original module underneath the FSDP wrapper.
        self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module

        if self._is_offload_param:
            # Only grads are offloaded here: per the original note, params
            # are required during state_dict in the sharding manager.
            offload_fsdp_grad(module=self.actor_module_fsdp)
            log_gpu_memory_usage('After offload actor grad during init', logger=logger)

        if self._is_offload_optimizer:
            offload_fsdp_optimizer(optimizer=self.actor_optimizer)
            log_gpu_memory_usage('After offload actor optimizer during init', logger=logger)

    # NOTE(review): the original comment at this point read "load from
    # checkpoint"; no checkpoint loading is visible in this method -- confirm.
    if self._is_actor:
        actor_cfg = self.config.actor
        OmegaConf.set_struct(actor_cfg, True)
        self.actor = DataParallelPPOActor(
            config=actor_cfg,
            actor_module=self.actor_module_fsdp,
            actor_optimizer=self.actor_optimizer,
        )

    if self._is_rollout:
        self.rollout, self.sharding_manager = self._build_rollout()

    if self._is_ref:
        ref_cfg = self.config.ref
        # Only the first element of the builder's result is kept for the
        # reference policy; it is built with no optimizer config.
        ref_build = self._build_model_optimizer(
            model_path=model_cfg.path,
            fsdp_config=ref_cfg.fsdp_config,
            optim_config=None,
            override_model_config=override_model_config,
            trust_remote_code=model_cfg.get('trust_remote_code', False),
        )
        self.ref_module_fsdp = ref_build[0]

        if self._is_offload_param:
            offload_fsdp_param_and_grad(module=self.ref_module_fsdp, offload_grad=self._is_offload_grad)

        OmegaConf.set_struct(ref_cfg, True)
        self.ref_policy = DataParallelPPOActor(config=ref_cfg, actor_module=self.ref_module_fsdp)

    torch.cuda.synchronize()

Callers

nothing calls this directly

Calls 9

_build_rolloutMethod · 0.95
import_external_libsFunction · 0.90
offload_fsdp_gradFunction · 0.90
log_gpu_memory_usageFunction · 0.90
offload_fsdp_optimizerFunction · 0.90
getMethod · 0.45

Tested by

no test coverage detected