(model, context, max_gen_length, eos_token_id)
| 223 | # codebase. The lm-evaluation-harness code can now call this function |
| 224 | # in the same way it calls its generate function for GPT-style models. |
def generate_samples_eval(model, context, max_gen_length, eos_token_id):
    """Generate a continuation of ``context`` for lm-evaluation-harness tasks.

    Tokenizes ``context``, writes the generation settings onto the object
    returned by ``get_args()``, drains ``get_token_stream`` for the
    single-prompt batch and returns the decoded text that follows the prompt.

    Args:
        model: Model to sample from; ``model.eval()`` is called on it.
        context: Prompt text to condition on.
        max_gen_length: Number of tokens allowed beyond the tokenized prompt;
            added to the prompt length to form ``args.out_seq_length``.
        eos_token_id: Token id stored in ``args.eos_id``.

    Returns:
        The detokenized final output of the stream with its first
        ``len(context)`` characters removed.

    Raises:
        RuntimeError: If the token stream yields no items, so there is
            nothing to decode.

    Note:
        This overwrites ``args.out_seq_length`` and ``args.eos_id`` on the
        object returned by ``get_args()``. If that object is shared (as the
        name suggests; confirm), the new values persist after this call.
    """
    # TODO: revisit how the EOS token is handled (original note: "NEED TO
    # THINK ABOUT eos token").
    args = get_args()
    tokenizer = get_tokenizer()

    raw_text_len = len(context)
    model.eval()

    context_tokens = tokenizer.tokenize(context)
    # Generation settings reach get_token_stream through the args object.
    args.out_seq_length = max_gen_length + len(context_tokens)
    args.eos_id = eos_token_id

    # Track the last stream item explicitly. The original relied on the loop
    # variable leaking out of the for-loop, which raises UnboundLocalError
    # when the stream is empty.
    last_output = None
    with torch.no_grad():
        token_stream = get_token_stream(model, [context_tokens])
        # Drain the stream; only its final item is decoded below.
        for counter, output in enumerate(token_stream):
            last_output = output
            if counter == args.out_seq_length:
                break

    if last_output is None:
        raise RuntimeError(
            "get_token_stream produced no output for the given context")

    # Each stream item is unpacked as a pair; row 0 of its first element
    # holds the tokens for our single prompt.
    decode_tokens, _ = last_output
    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
    # Trims by character count of the raw prompt. This assumes detokenizing
    # the output reproduces `context` verbatim as its prefix. TODO confirm
    # for the tokenizer in use.
    return tokenizer.detokenize(decode_tokens)[raw_text_len:]
| 250 | |
| 251 | |
| 252 | def generate_samples_interactive_code_contest(model, print_frequency=10): |
nothing calls this directly
no test coverage detected