MCPcopy Index your code
hub / github.com/huggingface/diffusers / main

Function main

examples/discrete_diffusion/train_llada2.py:128–317  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

126
127
128def main():
129 cfg = parse_args()
130 if cfg.prompt_length >= cfg.max_length:
131 raise ValueError("`prompt_length` must be < `max_length`.")
132 if cfg.block_length <= 0:
133 raise ValueError("`block_length` must be > 0.")
134
135 project_config = ProjectConfiguration(project_dir=cfg.output_dir, logging_dir=os.path.join(cfg.output_dir, "logs"))
136 accelerator = Accelerator(
137 gradient_accumulation_steps=cfg.gradient_accumulation_steps,
138 project_config=project_config,
139 )
140 if accelerator.is_main_process:
141 os.makedirs(cfg.output_dir, exist_ok=True)
142 accelerator.wait_for_everyone()
143
144 set_seed(cfg.seed)
145 logger.info("Training configuration: %s", asdict(cfg))
146
147 tokenizer = AutoTokenizer.from_pretrained(cfg.model_name_or_path, use_fast=True, cache_dir=cfg.cache_dir)
148 if tokenizer.pad_token_id is None:
149 tokenizer.pad_token = tokenizer.eos_token
150
151 if tokenizer.mask_token_id is None:
152 tokenizer.add_special_tokens({"mask_token": "[MASK]"})
153
154 load_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
155 model = AutoModelForCausalLM.from_pretrained(cfg.model_name_or_path, cache_dir=cfg.cache_dir, dtype=load_dtype)
156 model.resize_token_embeddings(len(tokenizer))
157 if load_dtype == torch.float32:
158 model.to(dtype=torch.float32)
159
160 mask_token_id = int(tokenizer.mask_token_id)
161
162 if cfg.use_dummy_data:
163 dataset = RandomTokenDataset(
164 num_samples=cfg.num_dummy_samples,
165 seq_len=cfg.max_length,
166 vocab_size=len(tokenizer),
167 pad_token_id=int(tokenizer.pad_token_id),
168 )
169 train_dataloader = DataLoader(
170 dataset,
171 shuffle=True,
172 batch_size=cfg.per_device_train_batch_size,
173 drop_last=True,
174 )
175 else:
176 raw_datasets = load_dataset(cfg.dataset_name, cfg.dataset_config_name, cache_dir=cfg.cache_dir)
177 if "train" not in raw_datasets:
178 raise ValueError(f"Dataset {cfg.dataset_name} has no 'train' split.")
179
180 with accelerator.main_process_first():
181 tokenized = raw_datasets["train"].map(
182 lambda ex: tokenize_fn(ex, tokenizer, cfg.text_column, cfg.max_length),
183 batched=True,
184 remove_columns=raw_datasets["train"].column_names,
185 desc="Tokenizing",

Callers 1

train_llada2.py — File · 0.70

Calls 15

add_noise — Method · 0.95
set_seed — Function · 0.90
RandomTokenDataset — Class · 0.85
load_dataset — Function · 0.85
tokenize_fn — Function · 0.85
info — Method · 0.80
parameters — Method · 0.80
parse_args — Function · 0.70
get_scheduler — Function · 0.50
from_pretrained — Method · 0.45

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…