(tokenizer, args, output)
| 183 | |
| 184 | |
def read_context(tokenizer, args, output):
    """Read one prompt on model-parallel rank 0 and broadcast it to the group.

    Rank 0 keeps prompting on stdin until it receives either the word
    ``stop`` or a non-empty prompt whose token count is strictly below
    ``args.seq_length``.  All other ranks skip the prompt and obtain the
    termination flag, the context length and the token ids through
    ``torch.distributed.broadcast``.

    Args:
        tokenizer: Object providing ``EncodeAsIds(text).tokenization``,
            ``get_command(name).Id`` and ``DecodeIds(ids)``.
        args: Namespace from which ``task_mask``, ``block_lm`` and
            ``seq_length`` are read.
        output: Writable stream.  Every non-empty prompt (after the mask
            token has been appended, if applicable) is written to it
            *before* the length check, so rejected over-long prompts are
            written as well.

    Returns:
        A 4-tuple ``(terminate_runs, raw_text, context_tokens_tensor,
        context_length)``.  When ``terminate_runs`` is 1 the remaining three
        entries are ``None``.  On ranks other than 0, ``raw_text`` is the
        text decoded back from the broadcast token ids.
    """
    is_source = mpu.get_model_parallel_rank() == 0
    src_rank = mpu.get_model_parallel_src_rank()
    group = mpu.get_model_parallel_group()

    terminate_runs = 0
    # Bound on every rank and every path; rank 0 overwrites it once a prompt
    # has been tokenized.
    context_length = 0

    if is_source:
        while True:
            try:
                raw_text = input("\nContext prompt (stop to exit) >>> ")
            except EOFError:
                # stdin was closed or exhausted.  Treat it like "stop" so the
                # termination flag is still broadcast instead of this rank
                # raising before the collective call below.
                terminate_runs = 1
                break
            if not raw_text:
                print('Prompt should not be empty!')
                continue
            if raw_text == "stop":
                terminate_runs = 1
                break

            generation_mask = '[gMASK]' if args.task_mask else '[MASK]'
            # 'MASK]' matches both '[MASK]' and '[gMASK]' (and any other
            # token ending that way) already present in the prompt.
            if args.block_lm and 'MASK]' not in raw_text:
                raw_text += ' ' + generation_mask
            output.write(raw_text)

            context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization
            if args.block_lm:
                context_tokens = [tokenizer.get_command('ENC').Id] + context_tokens
                # No 'eos' is appended when the prompt ends with '[gMASK]'.
                if not raw_text.endswith('[gMASK]'):
                    context_tokens = context_tokens + [tokenizer.get_command('eos').Id]
            context_length = len(context_tokens)

            if context_length >= args.seq_length:
                print("\nContext length", context_length,
                      "\nPlease give smaller context than the window length!")
                continue
            break

    # Step 1: share whether the user asked to stop.
    terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
    torch.distributed.broadcast(terminate_runs_tensor, src_rank, group=group)
    terminate_runs = terminate_runs_tensor[0].item()

    if terminate_runs == 1:
        return terminate_runs, None, None, None

    # Step 2: share the length so receivers can allocate a matching buffer.
    context_length_tensor = torch.cuda.LongTensor([context_length])
    torch.distributed.broadcast(context_length_tensor, src_rank, group=group)
    context_length = context_length_tensor[0].item()

    # Step 3: share the token ids themselves.
    if is_source:
        context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    else:
        context_tokens_tensor = torch.cuda.LongTensor([0] * context_length)
    torch.distributed.broadcast(context_tokens_tensor, src_rank, group=group)

    if not is_source:
        # Receivers never saw the typed text; rebuild it from the token ids.
        raw_text = tokenizer.DecodeIds(context_tokens_tensor.tolist())
    return terminate_runs, raw_text, context_tokens_tensor, context_length
| 237 | |
| 238 | |
| 239 | def generate_samples(model, tokenizer, args, device): |
no test coverage detected