(args)
| 53 | |
| 54 | |
| 55 | def initialize_model_and_tokenizer(args): |
| 56 | tokenizer = get_tokenizer(args) |
| 57 | |
| 58 | torch.distributed.barrier() |
| 59 | start = time.time() |
| 60 | |
| 61 | for i in range(get_model_parallel_world_size()): |
| 62 | if get_model_parallel_rank() == i: |
| 63 | # Initialize model |
| 64 | model = GLM130B(args).half() |
| 65 | |
| 66 | if args.from_quantized_checkpoint: |
| 67 | assert args.quantization_bit_width is not None |
| 68 | # Quantize model before moving to GPU |
| 69 | model = quantize(model, args.quantization_bit_width) |
| 70 | |
| 71 | # Load checkpoint |
| 72 | load_checkpoint(model, args) |
| 73 | |
| 74 | if args.quantization_bit_width is not None and not args.from_quantized_checkpoint: |
| 75 | # Quantize model before moving to GPU |
| 76 | model = quantize(model, args.quantization_bit_width) |
| 77 | |
| 78 | if args.bminf: |
| 79 | import bminf |
| 80 | |
| 81 | if torch.distributed.get_rank() == 0: |
| 82 | print(f"> BMInf activated, memory limit: {args.bminf_memory_limit} GB") |
| 83 | with torch.cuda.device(args.device): |
| 84 | model = bminf.wrapper(model, quantization=False, memory_limit=args.bminf_memory_limit << 30) |
| 85 | else: |
| 86 | model = model.to(args.device) |
| 87 | if args.sequential_initialization: |
| 88 | torch.distributed.barrier(group=get_model_parallel_group()) |
| 89 | |
| 90 | torch.distributed.barrier() |
| 91 | if torch.distributed.get_rank() == 0: |
| 92 | print(f"> Model initialized in {time.time() - start:.1f}s") |
| 93 | |
| 94 | torch.cuda.empty_cache() |
| 95 | model.eval() |
| 96 | |
| 97 | # generate rotary embedding cache |
| 98 | original_parallel_output = model.transformer.parallel_output |
| 99 | model.transformer.parallel_output = True |
| 100 | with torch.no_grad(): |
| 101 | _, *_ = model( |
| 102 | torch.ones(1, args.max_sequence_length, device=torch.cuda.current_device(), dtype=torch.int64), |
| 103 | torch.arange(args.max_sequence_length, device=torch.cuda.current_device(), dtype=torch.int64).view(1, -1), |
| 104 | torch.randn( |
| 105 | 1, |
| 106 | 1, |
| 107 | args.max_sequence_length, |
| 108 | args.max_sequence_length, |
| 109 | device=torch.cuda.current_device(), |
| 110 | ) |
| 111 | < 0.5, |
| 112 | ) |
no test coverage detected