(self)
| 309 | |
| 310 | class Sampler: |
def __init__(self):
    """Prepare the sampler: device mesh, tokenizers, RNG state, then the model.

    Takes no arguments; every setting is read from the module-level ``FLAGS``.
    ``self._load_model()`` (defined elsewhere on this class) is invoked last,
    after all of the attributes below have been assigned.
    """
    # JAX device mesh built via LLaMAConfig from the ``mesh_dim`` flag.
    self.mesh = LLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)

    tokenizer_source = FLAGS.tokenizer
    # The prefix tokenizer truncates and pads on the LEFT: when a prompt is
    # too long its tail is kept, and padding is placed before the text.
    # NOTE(review): presumably chosen so prompts end flush against the
    # generation position — confirm against the sampling code.
    left_side = {'truncation_side': 'left', 'padding_side': 'left'}
    self.prefix_tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, **left_side)
    # Same tokenizer source, loaded with the library's default settings.
    self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)

    # RNG state is drawn before the model is loaded (order kept on purpose).
    self.sharded_rng = next_rng()
    self._load_model()
| 317 | |
| 318 | @property |
| 319 | def block_size(self): |
nothing calls this directly
no test coverage detected