(model, tokenizer)
| 80 | return metadata |
| 81 | |
def validate_model(model, tokenizer, prompt="yo", expected="How's it going?"):
    """Smoke-test a model by checking its reply to a fixed single-turn chat prompt.

    Builds a chat token sequence for the user message ``prompt``, followed by an
    empty assistant header. The sequence uses the tokenizer's
    ``<|start_header_id|>``, ``<|end_header_id|>`` and ``<|eot_id|>`` special
    tokens. The sequence is fed one token at a time through a ``TinyJit``-wrapped
    ``model.forward``. Generated tokens are then decoded until a stop token
    appears or the decoded text grows longer than ``expected``.

    Args:
        model: object exposing ``forward`` and ``max_context``.
        tokenizer: object exposing ``bos_id``, ``special_tokens``, ``encode``,
            ``decode`` and ``stop_tokens``.
        prompt: user message to send. Defaults to the previously hard-coded "yo".
        expected: exact reply the model must produce, followed by a stop token.

    Raises:
        AssertionError: if the decoded reply differs from ``expected``.
    """
    special = tokenizer.special_tokens

    def header(role):
        # Token layout: <|start_header_id|> role <|end_header_id|> "\n\n"
        return (
            [special["<|start_header_id|>"]]
            + tokenizer.encode(role)
            + [special["<|end_header_id|>"]]
            + tokenizer.encode("\n\n")
        )

    toks = [tokenizer.bos_id] + header("user")
    toks += tokenizer.encode(prompt) + [special["<|eot_id|>"]]
    toks += header("assistant")

    run = TinyJit(model.forward)

    def step(token, pos):
        # The trailing positional args are passed exactly as before.
        # NOTE(review): presumably (temperature, top_k, top_p, alpha_f, alpha_p)
        # configured for deterministic decoding; confirm against model.forward.
        return run(
            Tensor([[token]]),
            Variable("start_pos", 0, model.max_context).bind(pos),
            0.0, 0, 0.0, 0.0, 0.0,
        )

    # Feed every prompt token except the last; those outputs are discarded.
    start_pos = 0
    for tok in toks[:-1]:
        step(tok, start_pos).realize()
        start_pos += 1

    tok = toks[-1]
    result = ""
    while True:
        tok = step(tok, start_pos).item()
        start_pos += 1
        # '>' (not '>=') allows one more token after an exact-length match.
        # The model must then emit a stop token; any other token is appended
        # and makes the comparison below fail.
        if tok in tokenizer.stop_tokens or len(result) > len(expected):
            break
        result += tokenizer.decode([tok])

    # Explicit raise instead of `assert` so the check still runs under `python -O`.
    if result != expected:
        raise AssertionError(
            f"Model validation failed, expected output: {expected}, actual output: {result}"
        )
| 102 | |
| 103 | if __name__=="__main__": |
| 104 | # Export BPE data for use with tiktoken.js |
no test coverage detected
searching dependent graphs…