MCPcopy Index your code
hub / github.com/THUDM/GLM / generate_samples

Function generate_samples

generate_samples.py:239–286  ·  view source on GitHub ↗
(model, tokenizer, args, device)

Source from the content-addressed store, hash-verified

237
238
def generate_samples(model, tokenizer, args, device):
    """Run a generation loop and log each sample to a timestamped file.

    Each iteration synchronises the model-parallel group, obtains a context
    via ``read_context``, generates a continuation with ``sample_sequence``
    and, on model-parallel rank 0, prints the result and appends it to
    ``./samples/sample-<MM-DD-HH-MM>.txt``. The loop ends when
    ``read_context`` reports ``terminate_runs == 1``.

    Args:
        model: Callable model; switched to eval mode before generation.
        tokenizer: Provides ``get_command(name).Id`` and ``DecodeIds``.
        args: Run configuration. Fields read here: ``block_lm``,
            ``task_mask``, ``eod_token``, ``no_block_position`` and
            ``out_seq_length``.
        device: Forwarded to ``get_batch`` and ``sample_sequence``.

    Returns:
        None.
    """
    model.eval()

    output_dir = "./samples"
    # Every model-parallel rank executes this function, so a separate
    # "exists?" check followed by makedirs can race between processes;
    # exist_ok=True makes directory creation safe under that race.
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"sample-{datetime.now().strftime('%m-%d-%H-%M')}.txt")

    # Explicit UTF-8 so decoded text containing non-ASCII characters can be
    # written regardless of the platform's default locale encoding.
    with torch.no_grad(), open(output_path, "w", encoding="utf-8") as output:
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())

            terminate_runs, raw_text, context_tokens_tensor, context_length = read_context(tokenizer, args, output)
            if terminate_runs == 1:
                return
            start_time = time.time()

            if args.block_lm:
                mems = []
                tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, device, args)

                mask_names = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK']
                mask_token_ids = [tokenizer.get_command(name).Id for name in mask_names]
                end_tokens = [tokenizer.get_command('eop').Id, args.eod_token]

                # Collect the index of every mask token in the context, in
                # ascending order, so the blanks are filled left to right.
                mask_positions = []
                for token_id in mask_token_ids:
                    mask_positions.extend(
                        (context_tokens_tensor == token_id).nonzero(as_tuple=True)[0].tolist())
                mask_positions.sort()

                if args.no_block_position:
                    # Shift the position ids after each mask by out_seq_length
                    # (presumably to reserve room for the generated span --
                    # TODO confirm against sample_sequence).
                    for mask_position in mask_positions:
                        position_ids[0, mask_position + 1:] += args.out_seq_length

                # Run the context through the model once; everything after the
                # first output is kept as the memory passed to sample_sequence.
                _, *mems = model(tokens, position_ids, attention_mask, *mems)

                for mask_position in mask_positions:
                    if args.no_block_position:
                        position = position_ids[0, mask_position].item()
                    else:
                        position = mask_position
                    tokens, mems = sample_sequence(model, tokenizer, tokens, position,
                                                   args, device, mems=mems, end_tokens=end_tokens)
            else:
                tokens, _ = sample_sequence(model, tokenizer, context_tokens_tensor, context_length, args, device)

            output_tokens_list = tokens.view(-1).contiguous()

            # Only rank 0 of the model-parallel group reports and records output.
            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
                print("\nContext:", raw_text, flush=True)
                # Decode only the tokens that follow the original context.
                decode_tokens = tokenizer.DecodeIds(output_tokens_list[context_length:].tolist())
                print("\nGLM:", decode_tokens, flush=True)
                output.write(decode_tokens + "\n")

            torch.distributed.barrier(group=mpu.get_model_parallel_group())
287
288
289def main():

Callers 1

mainFunction · 0.85

Calls 7

read_contextFunction · 0.85
sample_sequenceFunction · 0.85
get_commandMethod · 0.80
get_batchFunction · 0.70
existsMethod · 0.45
DecodeIdsMethod · 0.45
writeMethod · 0.45

Tested by

no test coverage detected