MCPcopy
hub / github.com/Tele-AI/Telechat / main

Function main

deepspeed-telechat/sft/main.py:241–401  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

239 return loss
240
241def main():
242 args = parse_args()
243
244 if args.local_rank == -1:
245 device = torch.device("cuda")
246 else:
247 torch.cuda.set_device(args.local_rank)
248 device = torch.device("cuda", args.local_rank)
249 # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
250 # torch.distributed.init_process_group(backend='nccl')
251 deepspeed.init_distributed()
252
253 args.global_rank = torch.distributed.get_rank()
254
255 ds_config = get_train_ds_config(offload=args.offload,
256 stage=args.zero_stage,
257 precision=args.precision)
258 ds_config[
259 'train_micro_batch_size_per_gpu'] = args.per_device_train_batch_size
260 ds_config[
261 'train_batch_size'] = args.per_device_train_batch_size * torch.distributed.get_world_size(
262 ) * args.gradient_accumulation_steps
263 loss_update_steps = args.per_device_train_batch_size * args.gradient_accumulation_steps
264
265 # If passed along, set the training seed now.
266 set_random_seed(args.seed)
267
268 torch.distributed.barrier()
269
270 tokenizer = load_telechat_tokenizer(args.model_name_or_path, fast_tokenizer=True)
271 args.user_token_id = tokenizer.convert_tokens_to_ids(args.user_token)
272 args.bot_token_id = tokenizer.convert_tokens_to_ids(args.bot_token)
273 args.end_token_id = tokenizer.convert_tokens_to_ids(args.end_token)
274
275
276 model = create_hf_telechat(args.model_name_or_path,
277 args.precision,
278 ds_config,
279 disable_dropout=args.disable_dropout)
280
281 if args.lora_dim > 0:
282 model = convert_linear_layer_to_lora(model, args.lora_module_name, args.lora_scaling,
283 args.lora_dim)
284 if args.mark_only_lora_as_trainable:
285 mark_only_lora_as_trainable(model, 'lora_only')
286 make_model_gradient_checkpointing_compatible(model)
287 if args.only_optimize_lora:
288 model = only_optimize_lora_parameters(model)
289
290 # Prepare the data
291 print(f"train_fname:{args.data_path}")
292 assert os.path.exists(args.data_path), "Please process data first!"
293 torch.distributed.barrier()
294 train_dataset = get_dataset(args.data_path, args.seed)
295
296 # DataLoaders creation:
297 if args.local_rank == -1:
298 train_sampler = RandomSampler(train_dataset)

Callers 1

main.pyFile · 0.70

Calls 15

get_train_ds_configFunction · 0.90
set_random_seedFunction · 0.90
get_datasetFunction · 0.90
print_rank_0Function · 0.90
to_deviceFunction · 0.90
save_hf_formatFunction · 0.90

Tested by

no test coverage detected