MCP · copy
hub / github.com/zai-org/GLM-130B / initialize_model_and_tokenizer

Function initialize_model_and_tokenizer

initialize.py:55–116  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

53
54
55def initialize_model_and_tokenizer(args):
56 tokenizer = get_tokenizer(args)
57
58 torch.distributed.barrier()
59 start = time.time()
60
61 for i in range(get_model_parallel_world_size()):
62 if get_model_parallel_rank() == i:
63 # Initialize model
64 model = GLM130B(args).half()
65
66 if args.from_quantized_checkpoint:
67 assert args.quantization_bit_width is not None
68 # Quantize model before moving to GPU
69 model = quantize(model, args.quantization_bit_width)
70
71 # Load checkpoint
72 load_checkpoint(model, args)
73
74 if args.quantization_bit_width is not None and not args.from_quantized_checkpoint:
75 # Quantize model before moving to GPU
76 model = quantize(model, args.quantization_bit_width)
77
78 if args.bminf:
79 import bminf
80
81 if torch.distributed.get_rank() == 0:
82 print(f"> BMInf activated, memory limit: {args.bminf_memory_limit} GB")
83 with torch.cuda.device(args.device):
84 model = bminf.wrapper(model, quantization=False, memory_limit=args.bminf_memory_limit << 30)
85 else:
86 model = model.to(args.device)
87 if args.sequential_initialization:
88 torch.distributed.barrier(group=get_model_parallel_group())
89
90 torch.distributed.barrier()
91 if torch.distributed.get_rank() == 0:
92 print(f"> Model initialized in {time.time() - start:.1f}s")
93
94 torch.cuda.empty_cache()
95 model.eval()
96
97 # generate rotary embedding cache
98 original_parallel_output = model.transformer.parallel_output
99 model.transformer.parallel_output = True
100 with torch.no_grad():
101 _, *_ = model(
102 torch.ones(1, args.max_sequence_length, device=torch.cuda.current_device(), dtype=torch.int64),
103 torch.arange(args.max_sequence_length, device=torch.cuda.current_device(), dtype=torch.int64).view(1, -1),
104 torch.randn(
105 1,
106 1,
107 args.max_sequence_length,
108 args.max_sequence_length,
109 device=torch.cuda.current_device(),
110 )
111 < 0.5,
112 )

Callers 3

benchmark.py — File · 0.90
main — Function · 0.90
main — Function · 0.90

Calls 1

quantize — Function · 0.90

Tested by

no test coverage detected