MCPcopy Index your code
hub / github.com/LargeWorldModel/LWM / _load_model

Method _load_model

scripts/eval_needle.py:327–372  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

325 return self.mesh.shape['dp'] * self.mesh.shape['fsdp']
326
def _load_model(self):
    """Resolve the LLaMA config from flags, load checkpoint params, and shard them.

    Config resolution:
      * If ``--load_llama_config`` is non-empty, the base config is loaded from
        it, and only the scan/chunking/layout fields are taken from ``--llama``.
      * Otherwise the whole config is built from ``--llama``.
      * ``--update_llama_config`` (if non-empty) is then ``eval``'d and applied.
      * BOS/EOS ids come from ``self.tokenizer``; ``mesh_dim`` from ``--mesh_dim``.

    Side effects: sets ``self.config``, ``self.params``, ``self.model`` and
    ``self.model_ps``; ``self.params`` ends up sharded inside ``self.mesh``.
    """
    if FLAGS.load_llama_config == '':
        llama_config = LLaMAConfig(**FLAGS.llama)
    else:
        llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)
        updates = LLaMAConfig(**FLAGS.llama)
        # Only these runtime/layout knobs are overridden from --llama; every
        # other field keeps the value from the loaded config file.
        llama_config.update({
            field: getattr(updates, field)
            for field in (
                'scan_attention',
                'scan_mlp',
                'scan_query_chunk_size',
                'scan_key_chunk_size',
                'scan_mlp_chunk_size',
                'scan_layers',
                'param_scan_axis',
            )
        })

    if FLAGS.update_llama_config != '':
        # NOTE(review): eval() executes arbitrary Python taken from the command
        # line. Tolerable for an operator-run eval script, but never feed it
        # untrusted text. ast.literal_eval is not a drop-in replacement: it
        # would reject the ``dict(k=v)`` form that dict(eval(...)) accepts.
        llama_config.update(dict(eval(FLAGS.update_llama_config)))

    llama_config.update({
        'bos_token_id': self.tokenizer.bos_token_id,
        'eos_token_id': self.tokenizer.eos_token_id,
    })
    llama_config.update({'mesh_dim': FLAGS.mesh_dim})
    self.config = llama_config

    # Load under the first CPU device as default, so arrays created while
    # reading the checkpoint default to host memory rather than an accelerator.
    cpu_device = jax.devices("cpu")[0]
    with jax.default_device(cpu_device):
        _, self.params = StreamingCheckpointer.load_trainstate_checkpoint(
            FLAGS.load_checkpoint,
            disallow_trainstate=True,
            max_buffer_size=32 * 2 ** 30,  # 32 * 2**30 (32 GiB if the unit is bytes)
        )

    float_dtype = get_float_dtype_by_name(FLAGS.dtype)
    # _do_init=False: the model is constructed without initializing weights;
    # the parameters come from the checkpoint above. The leading 512 in
    # input_shape is presumably a batch dimension for shape tracing — TODO confirm.
    self.model = FlaxLLaMAForCausalLM(
        llama_config,
        input_shape=(512, self.block_size),
        seed=FLAGS.seed,
        _do_init=False,
        dtype=float_dtype,
    )

    partition_rules = LLaMAConfig.get_partition_rules(
        llama_config.scan_layers, llama_config.param_scan_axis
    )
    self.model_ps = match_partition_rules(partition_rules, self.params)
    shard_fns, _ = make_shard_and_gather_fns(self.model_ps, float_dtype)

    # Apply the per-leaf shard functions inside the mesh context.
    with self.mesh:
        self.params = tree_apply(shard_fns, self.params)
373
374 @cached_property
375 def _forward_generate(self):

Callers 1

__init__ (method) · 0.95

Calls 4

LLaMAConfig (class) · 0.90
load_config (method) · 0.45
get_partition_rules (method) · 0.45

Tested by

no test coverage detected