MCPcopy Index your code
hub / github.com/LargeWorldModel/LWM / _load_model

Method _load_model

scripts/eval_needle_multi.py:336–381  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

334 return self.mesh.shape['dp'] * self.mesh.shape['fsdp']
335
def _load_model(self):
    """Resolve the LLaMA config, load checkpoint weights, and shard them.

    Config resolution, applied in this order:

    1. If ``FLAGS.load_llama_config`` is non-empty, load that config and copy
       onto it only the seven scan-related fields of
       ``LLaMAConfig(**FLAGS.llama)``. Otherwise build the config directly
       from ``FLAGS.llama``.
    2. If ``FLAGS.update_llama_config`` is non-empty, ``eval`` it, convert
       the result with ``dict(...)`` and apply it as overrides.
    3. Write in the tokenizer's BOS/EOS token ids, then ``FLAGS.mesh_dim``.

    Attributes set on ``self``: ``config``, ``params``, ``model`` and
    ``model_ps``. ``params`` is first read from ``FLAGS.load_checkpoint``
    (with ``disallow_trainstate=True``) under a CPU default-device context.
    It is then replaced by the result of applying the shard functions inside
    ``self.mesh``.
    """
    # Fields that the ``--llama`` flags may override on top of a loaded
    # config file.
    scan_field_names = (
        'scan_attention',
        'scan_mlp',
        'scan_query_chunk_size',
        'scan_key_chunk_size',
        'scan_mlp_chunk_size',
        'scan_layers',
        'param_scan_axis',
    )

    if FLAGS.load_llama_config == '':
        cfg = LLaMAConfig(**FLAGS.llama)
    else:
        cfg = LLaMAConfig.load_config(FLAGS.load_llama_config)
        flag_cfg = LLaMAConfig(**FLAGS.llama)
        # With an explicit config file, the flags contribute only the
        # scan/chunking knobs; every other field comes from the file.
        cfg.update({name: getattr(flag_cfg, name) for name in scan_field_names})

    if FLAGS.update_llama_config != '':
        # NOTE(review): eval() executes arbitrary Python taken from a
        # command-line flag. Only pass operator-controlled text here. Switching
        # to ast.literal_eval would reject ``dict(...)``-style strings; confirm
        # how callers format this flag before changing it.
        cfg.update(dict(eval(FLAGS.update_llama_config)))

    cfg.update({
        'bos_token_id': self.tokenizer.bos_token_id,
        'eos_token_id': self.tokenizer.eos_token_id,
    })
    cfg.update({'mesh_dim': FLAGS.mesh_dim})
    self.config = cfg

    # NOTE(review): the nesting below was reconstructed. The model build,
    # partition matching and sharding are placed inside the CPU default-device
    # context. Confirm this against the upstream file.
    cpu_device = jax.devices("cpu")[0]
    with jax.default_device(cpu_device):
        _, self.params = StreamingCheckpointer.load_trainstate_checkpoint(
            FLAGS.load_checkpoint,
            disallow_trainstate=True,
            max_buffer_size=32 * 2 ** 30,
        )
        self.model = FlaxLLaMAForCausalLM(
            cfg,
            input_shape=(512, self.block_size),
            seed=FLAGS.seed,
            _do_init=False,
            dtype=get_float_dtype_by_name(FLAGS.dtype),
        )
        partition_rules = LLaMAConfig.get_partition_rules(
            cfg.scan_layers, cfg.param_scan_axis
        )
        self.model_ps = match_partition_rules(partition_rules, self.params)
        shard_fns, _ = make_shard_and_gather_fns(
            self.model_ps, get_float_dtype_by_name(FLAGS.dtype)
        )

        with self.mesh:
            self.params = tree_apply(shard_fns, self.params)
382
383 @cached_property
384 def _forward_generate(self):

Callers 1

__init__ (Method) · 0.95

Calls 4

LLaMAConfig (Class) · 0.90
load_config (Method) · 0.45
get_partition_rules (Method) · 0.45

Tested by

no test coverage detected