MCPcopy
hub / github.com/hpcaitech/Open-Sora / load

Function load

opensora/utils/ckpt_utils.py:346–386  ·  view source on GitHub ↗
(
    booster: Booster,
    load_dir: str,
    model: nn.Module = None,
    ema: nn.Module = None,
    optimizer: Optimizer = None,
    lr_scheduler: _LRScheduler = None,
    sampler=None,
    is_lora_train: bool = False,
    is_load_both_lora_and_main: bool = False,
)

Source from the content-addressed store, hash-verified

344
345
def load(
    booster: Booster,
    load_dir: str,
    model: nn.Module = None,
    ema: nn.Module = None,
    optimizer: Optimizer = None,
    lr_scheduler: _LRScheduler = None,
    sampler=None,
    is_lora_train: bool = False,
    is_load_both_lora_and_main: bool = False,
) -> Tuple[int, int]:
    """Restore training state from a checkpoint directory.

    Every component argument is optional; a component is restored only when
    its argument is not ``None``. All paths are relative to ``load_dir``:

    - ``running_states.json``: always read; supplies the returned epoch/step.
    - ``model``: ``lora/adapter_model.bin`` when ``is_lora_train`` is set,
      otherwise ``model``. When both ``is_lora_train`` and
      ``is_load_both_lora_and_main`` are set, the adapter is loaded first and
      then ``model``. Every model load uses ``strict=False``.
    - ``ema``: ``ema.pt``, deserialized onto CPU and loaded with
      ``strict=False``.
    - ``optimizer``: ``optimizer``, via ``booster.load_optimizer``.
    - ``lr_scheduler``: ``lr_scheduler``, via ``booster.load_lr_scheduler``.
    - ``sampler``: ``sampler``, via ``torch.load`` + ``load_state_dict``.

    Args:
        booster: Booster used to load the boosted model/optimizer/scheduler.
        load_dir: Checkpoint directory to read from.
        model: Model to restore weights into.
        ema: EMA copy of the model; loaded directly because it is not boosted.
        optimizer: Optimizer whose state is restored.
        lr_scheduler: LR scheduler whose state is restored.
        sampler: Any object with ``load_state_dict`` (e.g. a data sampler).
        is_lora_train: Load the LoRA adapter instead of the full model.
        is_load_both_lora_and_main: With ``is_lora_train``, also load the full
            model after the adapter. Ignored when ``is_lora_train`` is False.

    Returns:
        ``(epoch, step)`` as stored in ``running_states.json``.

    Raises:
        AssertionError: If ``load_dir`` or ``running_states.json`` is missing.
    """
    assert os.path.exists(load_dir), f"Checkpoint directory {load_dir} does not exist"
    assert os.path.exists(os.path.join(load_dir, "running_states.json")), "running_states.json does not exist"
    running_states = load_json(os.path.join(load_dir, "running_states.json"))

    if model is not None:
        lora_path = os.path.join(load_dir, "lora", "adapter_model.bin")
        main_path = os.path.join(load_dir, "model")
        # Adapter goes first so that, when both are requested, the full-model
        # checkpoint is applied afterwards (same order as before).
        if is_lora_train:
            booster.load_model(model, lora_path, strict=False)
        if not is_lora_train or is_load_both_lora_and_main:
            booster.load_model(model, main_path, strict=False)

    if ema is not None:
        # ema is not boosted, so we don't use booster.load_model
        ema.load_state_dict(
            torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu")),
            strict=False,
        )
    if optimizer is not None:
        booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer"))
    if lr_scheduler is not None:
        booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler"))
    if sampler is not None:
        sampler.load_state_dict(torch.load(os.path.join(load_dir, "sampler")))

    # Every participating process waits here before returning.
    dist.barrier()

    return (
        running_states["epoch"],
        running_states["step"],
    )
387
388
389def rm_checkpoints(

Callers 4

mainFunction · 0.90
mainFunction · 0.90
mainFunction · 0.90
mainFunction · 0.90

Calls 3

load_jsonFunction · 0.85
load_state_dictMethod · 0.45
deviceMethod · 0.45

Tested by

no test coverage detected