(model, tokenizer, args, device)
| 237 | |
| 238 | |
def _generate_block_lm(model, tokenizer, context_tokens_tensor, args, device):
    """Fill every mask token of a block-LM prompt and return the resulting tokens.

    Mask token ids come from ``tokenizer.get_command`` ('MASK' only, or also
    'sMASK'/'gMASK' when ``args.task_mask`` is set). Each mask is filled by
    ``sample_sequence``, which is given the 'eop' command id and
    ``args.eod_token`` as ``end_tokens``.
    """
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, device, args)
    mask_names = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK']
    mask_ids = [tokenizer.get_command(name).Id for name in mask_names]
    end_tokens = [tokenizer.get_command('eop').Id, args.eod_token]

    # Indices of every mask token in the prompt, in left-to-right order.
    mask_positions = []
    for mask_id in mask_ids:
        mask_positions.extend((context_tokens_tensor == mask_id).nonzero(as_tuple=True)[0].tolist())
    mask_positions.sort()

    if args.no_block_position:
        # Shift every position after a mask by out_seq_length -- presumably to
        # reserve room for the span generated at that mask (TODO confirm).
        for mask_position in mask_positions:
            position_ids[0, mask_position + 1:] += args.out_seq_length

    # Forward pass over the whole prompt: the first output is discarded and the
    # remaining outputs are carried into sample_sequence as ``mems``.
    _, *mems = model(tokens, position_ids, attention_mask)
    for mask_position in mask_positions:
        if args.no_block_position:
            position = position_ids[0, mask_position].item()
        else:
            position = mask_position
        tokens, mems = sample_sequence(model, tokenizer, tokens, position,
                                       args, device, mems=mems, end_tokens=end_tokens)
    return tokens


def generate_samples(model, tokenizer, args, device):
    """Run the sampling loop until ``read_context`` returns ``terminate_runs == 1``.

    Each iteration:
    1. Synchronizes the model-parallel group.
    2. Reads a prompt via ``read_context``.
    3. Generates a continuation, either by filling mask tokens (``args.block_lm``)
       or by plain ``sample_sequence``.

    Model-parallel rank 0 then prints the timing, the prompt and the decoded
    continuation, and writes the continuation to
    ``./samples/sample-<MM-DD-HH-MM>.txt``.
    The open file is also passed to ``read_context``, which may write to it;
    that helper is not shown here.
    """
    model.eval()
    output_dir = "./samples"
    # exist_ok avoids the exists()/makedirs() race when several ranks start together.
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"sample-{datetime.now().strftime('%m-%d-%H-%M')}.txt")
    # Explicit UTF-8: decoded model text may contain non-ASCII characters that a
    # locale-dependent default encoding cannot represent.
    with torch.no_grad(), open(output_path, "w", encoding="utf-8") as output:
        while True:
            # Keep all model-parallel ranks in lock-step before reading the next prompt.
            torch.distributed.barrier(group=mpu.get_model_parallel_group())

            terminate_runs, raw_text, context_tokens_tensor, context_length = read_context(tokenizer, args, output)
            if terminate_runs == 1:
                return
            start_time = time.time()
            if args.block_lm:
                tokens = _generate_block_lm(model, tokenizer, context_tokens_tensor, args, device)
            else:
                tokens, _ = sample_sequence(model, tokenizer, context_tokens_tensor, context_length, args, device)
            output_tokens_list = tokens.view(-1).contiguous()
            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
                print("\nContext:", raw_text, flush=True)
                # Only the tokens after the prompt are decoded and reported.
                decode_tokens = tokenizer.DecodeIds(output_tokens_list[context_length:].tolist())
                print("\nGLM:", decode_tokens, flush=True)
                output.write(decode_tokens + "\n")
                # Flush per sample so results survive an abrupt kill of this long-lived loop.
                output.flush()

            torch.distributed.barrier(group=mpu.get_model_parallel_group())
| 287 | |
| 288 | |
| 289 | def main(): |
no test coverage detected