(self)
| 325 | return self.mesh.shape['dp'] * self.mesh.shape['fsdp'] |
| 326 | |
def _load_model(self):
    """Assemble the LLaMA config, restore checkpoint weights and shard them.

    Sets ``self.config``, ``self.model`` and ``self.model_ps``. It also sets
    ``self.params``, first to the restored checkpoint parameters and finally to
    the result of applying the shard functions under ``self.mesh``.
    """
    # With a config file, only these scan/layout fields are copied over from
    # ``FLAGS.llama``; every other field keeps the value loaded from the file
    # (later updates below may still override some of them).
    scan_fields = (
        'scan_attention',
        'scan_mlp',
        'scan_query_chunk_size',
        'scan_key_chunk_size',
        'scan_mlp_chunk_size',
        'scan_layers',
        'param_scan_axis',
    )
    if FLAGS.load_llama_config == '':
        config = LLaMAConfig(**FLAGS.llama)
    else:
        config = LLaMAConfig.load_config(FLAGS.load_llama_config)
        flag_config = LLaMAConfig(**FLAGS.llama)
        config.update({field: getattr(flag_config, field) for field in scan_fields})

    if FLAGS.update_llama_config != '':
        # NOTE(review): eval() executes arbitrary Python taken from the command
        # line. Tolerable only while this flag is operator-controlled; consider
        # ast.literal_eval if the accepted syntax can be restricted to literals.
        config.update(dict(eval(FLAGS.update_llama_config)))

    # Special-token ids come from the tokenizer; the mesh layout from the flags.
    config.update({
        'bos_token_id': self.tokenizer.bos_token_id,
        'eos_token_id': self.tokenizer.eos_token_id,
    })
    config.update({'mesh_dim': FLAGS.mesh_dim})
    self.config = config

    cpu = jax.devices("cpu")[0]
    with jax.default_device(cpu):
        # Restore weights with the CPU as default device. The buffer limit is
        # 32 * 2**30 (32 GiB, presumably measured in bytes).
        _, self.params = StreamingCheckpointer.load_trainstate_checkpoint(
            FLAGS.load_checkpoint,
            disallow_trainstate=True,
            max_buffer_size=32 * 2 ** 30,
        )
        compute_dtype = get_float_dtype_by_name(FLAGS.dtype)
        # _do_init=False: parameters are supplied separately via self.params.
        self.model = FlaxLLaMAForCausalLM(
            config,
            input_shape=(512, self.block_size),
            seed=FLAGS.seed,
            _do_init=False,
            dtype=compute_dtype,
        )
        rules = LLaMAConfig.get_partition_rules(
            config.scan_layers, config.param_scan_axis
        )
        self.model_ps = match_partition_rules(rules, self.params)
        shard_fns, _ = make_shard_and_gather_fns(self.model_ps, compute_dtype)

        # Re-place every parameter leaf according to its partition spec.
        with self.mesh:
            self.params = tree_apply(shard_fns, self.params)
| 373 | |
| 374 | @cached_property |
| 375 | def _forward_generate(self): |
no test coverage detected