MCPcopy Index your code
hub / github.com/LargeWorldModel/LWM / _load_model

Method _load_model

lwm/vision_chat.py:148–194  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

146
147
def _load_model(self):
    """Resolve the VideoLLaMA config, build the model, and load sharded params.

    Config resolution:
      * If ``FLAGS.load_llama_config`` is non-empty, the config is loaded from
        it, and only the scan-related fields (plus ``param_scan_axis``) are
        overridden from ``FLAGS.llama``.
      * Otherwise the config is built entirely from ``FLAGS.llama``.
      * If ``FLAGS.update_llama_config`` is non-empty, it is evaluated and
        applied on top.
      * The tokenizer's BOS/EOS ids and ``FLAGS.mesh_dim`` are always applied
        last.

    Expects ``self.tokenizer``, ``self.block_size`` and ``self.mesh`` to be
    set already. Assigns ``self.config``, ``self.model``, ``self.params`` and
    ``self.model_ps``.
    """
    if FLAGS.load_llama_config == '':
        llama_config = VideoLLaMAConfig(**FLAGS.llama)
    else:
        llama_config = VideoLLaMAConfig.load_config(FLAGS.load_llama_config)
        flag_config = VideoLLaMAConfig(**FLAGS.llama)
        # Only the scan/chunking knobs are taken from the flags when a config
        # file is given; every other field comes from the loaded file.
        scan_fields = (
            'scan_attention',
            'scan_mlp',
            'scan_query_chunk_size',
            'scan_key_chunk_size',
            'scan_mlp_chunk_size',
            'scan_layers',
            'param_scan_axis',
        )
        llama_config.update(
            {field: getattr(flag_config, field) for field in scan_fields}
        )

    if FLAGS.update_llama_config != '':
        # NOTE(review): eval() executes arbitrary Python taken from the
        # command-line flag. That is tolerable only because the flag is
        # supplied by the operator; never route untrusted input here.
        llama_config.update(dict(eval(FLAGS.update_llama_config)))

    llama_config.update({
        'bos_token_id': self.tokenizer.bos_token_id,
        'eos_token_id': self.tokenizer.eos_token_id,
    })
    llama_config.update({'mesh_dim': FLAGS.mesh_dim})
    self.config = llama_config

    # The same float dtype is used for the model and for the shard functions.
    float_dtype = get_float_dtype_by_name(FLAGS.dtype)
    self.model = FlaxVideoLLaMAForCausalLM(
        llama_config,
        input_shape=(512, self.block_size),
        seed=FLAGS.seed,
        _do_init=False,  # presumably skips random init; weights come from the checkpoint below
        dtype=float_dtype,
    )

    # Load the checkpoint with the CPU as the default device; arrays are
    # only placed via the shard functions under self.mesh afterwards.
    cpu_device = jax.devices("cpu")[0]
    with jax.default_device(cpu_device):
        _, self.params = StreamingCheckpointer.load_trainstate_checkpoint(
            FLAGS.load_checkpoint,
            disallow_trainstate=True,
            max_buffer_size=32 * 2 ** 30,  # 32 GiB, assuming the unit is bytes
        )
        partition_rules = VideoLLaMAConfig.get_partition_rules(
            llama_config.scan_layers, llama_config.param_scan_axis
        )
        self.model_ps = match_partition_rules(partition_rules, self.params)
        shard_fns, _ = make_shard_and_gather_fns(self.model_ps, float_dtype)

    with self.mesh:
        self.params = tree_apply(shard_fns, self.params)
195
196 @cached_property
197 def _forward_generate(self):

Callers 1

__init__ · Method · 0.95

Calls 4

VideoLLaMAConfig · Class · 0.90
load_config · Method · 0.45
get_partition_rules · Method · 0.45

Tested by

no test coverage detected