(self)
| 334 | return self.mesh.shape['dp'] * self.mesh.shape['fsdp'] |
| 335 | |
def _load_model(self):
    """Assemble the LLaMA config, load the checkpoint and shard the params.

    Sets ``self.config``, ``self.params``, ``self.model`` and
    ``self.model_ps``; nothing is returned. Reads ``self.tokenizer``,
    ``self.block_size`` and ``self.mesh`` together with the module-level
    ``FLAGS``.
    """
    if FLAGS.load_llama_config == '':
        # No named config requested: build it directly from --llama.
        cfg = LLaMAConfig(**FLAGS.llama)
    else:
        cfg = LLaMAConfig.load_config(FLAGS.load_llama_config)
        # Only these scan/chunking fields from --llama are layered on top
        # of the loaded config; every other field keeps its loaded value.
        overrides = LLaMAConfig(**FLAGS.llama)
        scan_fields = (
            'scan_attention',
            'scan_mlp',
            'scan_query_chunk_size',
            'scan_key_chunk_size',
            'scan_mlp_chunk_size',
            'scan_layers',
            'param_scan_axis',
        )
        cfg.update({name: getattr(overrides, name) for name in scan_fields})

    if FLAGS.update_llama_config != '':
        # NOTE(review): eval() runs the flag text as arbitrary Python. That
        # is only acceptable while the flag comes from a trusted operator;
        # do not wire it to external input.
        extra_fields = eval(FLAGS.update_llama_config)
        cfg.update(dict(extra_fields))

    # Keep the model's special-token ids in sync with the tokenizer.
    cfg.update({
        'bos_token_id': self.tokenizer.bos_token_id,
        'eos_token_id': self.tokenizer.eos_token_id,
    })
    cfg.update({'mesh_dim': FLAGS.mesh_dim})
    self.config = cfg

    cpu_device = jax.devices("cpu")[0]
    # NOTE(review): the extent of this `with` body is assumed (everything
    # through the final sharding step) -- confirm against the original
    # indentation, which is not recoverable from this view.
    with jax.default_device(cpu_device):
        _, self.params = StreamingCheckpointer.load_trainstate_checkpoint(
            FLAGS.load_checkpoint,
            disallow_trainstate=True,
            max_buffer_size=32 * 2 ** 30,
        )
        # _do_init=False: the model object is created without initialising
        # parameters; the loaded ones in self.params are used instead.
        self.model = FlaxLLaMAForCausalLM(
            cfg,
            input_shape=(512, self.block_size),
            seed=FLAGS.seed,
            _do_init=False,
            dtype=get_float_dtype_by_name(FLAGS.dtype),
        )
        partition_rules = LLaMAConfig.get_partition_rules(
            cfg.scan_layers, cfg.param_scan_axis
        )
        self.model_ps = match_partition_rules(partition_rules, self.params)
        shard_fns, _gather_fns = make_shard_and_gather_fns(
            self.model_ps, get_float_dtype_by_name(FLAGS.dtype)
        )

        # Apply the per-leaf shard functions with the mesh active.
        with self.mesh:
            self.params = tree_apply(shard_fns, self.params)
| 382 | |
| 383 | @cached_property |
| 384 | def _forward_generate(self): |
no test coverage detected