(self)
| 146 | |
| 147 | |
def _load_model(self):
    """Build the model config, create the Flax model, and load sharded params.

    Steps, in order:

    1. Create a ``VideoLLaMAConfig`` -- either loaded from
       ``FLAGS.load_llama_config`` (with the scan-related fields taken from
       ``FLAGS.llama``) or built directly from ``FLAGS.llama``.
    2. Apply ``FLAGS.update_llama_config`` (if non-empty), the tokenizer's
       BOS/EOS token ids, and ``FLAGS.mesh_dim`` on top of it.
    3. Instantiate ``FlaxVideoLLaMAForCausalLM`` with ``_do_init=False``.
    4. Load the checkpoint at ``FLAGS.load_checkpoint`` with the CPU as the
       JAX default device, then apply the shard functions to the loaded
       parameters inside ``self.mesh``.

    Sets ``self.config``, ``self.model``, ``self.model_ps`` and
    ``self.params``; returns ``None``.
    """
    if FLAGS.load_llama_config == '':
        config = VideoLLaMAConfig(**FLAGS.llama)
    else:
        config = VideoLLaMAConfig.load_config(FLAGS.load_llama_config)
        # Only the scan-related settings are taken from the command-line
        # config; every other field keeps the value from the loaded config.
        overrides = VideoLLaMAConfig(**FLAGS.llama)
        scan_fields = (
            'scan_attention',
            'scan_mlp',
            'scan_query_chunk_size',
            'scan_key_chunk_size',
            'scan_mlp_chunk_size',
            'scan_layers',
            'param_scan_axis',
        )
        config.update({name: getattr(overrides, name) for name in scan_fields})

    if FLAGS.update_llama_config != '':
        # NOTE(review): eval() executes whatever expression the flag holds.
        # Acceptable only while the flag comes from a trusted operator;
        # never feed it untrusted input.
        extra_settings = eval(FLAGS.update_llama_config)
        config.update(dict(extra_settings))

    config.update({
        'bos_token_id': self.tokenizer.bos_token_id,
        'eos_token_id': self.tokenizer.eos_token_id,
    })
    config.update({'mesh_dim': FLAGS.mesh_dim})
    self.config = config

    # _do_init=False: the constructor is asked not to initialize parameters;
    # they are loaded from the checkpoint below instead.
    # NOTE(review): the leading 512 in input_shape is hard-coded -- presumably
    # a batch dimension; confirm against the model class.
    model_kwargs = dict(
        input_shape=(512, self.block_size),
        seed=FLAGS.seed,
        _do_init=False,
        dtype=get_float_dtype_by_name(FLAGS.dtype),
    )
    self.model = FlaxVideoLLaMAForCausalLM(config, **model_kwargs)

    # Load the checkpoint while the first CPU device is JAX's default device.
    # NOTE(review): max_buffer_size is 32 * 2**30 -- presumably bytes
    # (32 GiB); confirm against StreamingCheckpointer.
    cpu_device = jax.devices("cpu")[0]
    with jax.default_device(cpu_device):
        loaded = StreamingCheckpointer.load_trainstate_checkpoint(
            FLAGS.load_checkpoint,
            disallow_trainstate=True,
            max_buffer_size=32 * 2 ** 30,
        )
        _, self.params = loaded

    # Match the config's partition rules against the loaded parameter tree.
    partition_rules = VideoLLaMAConfig.get_partition_rules(
        config.scan_layers, config.param_scan_axis
    )
    self.model_ps = match_partition_rules(partition_rules, self.params)
    shard_fns, _ = make_shard_and_gather_fns(
        self.model_ps, get_float_dtype_by_name(FLAGS.dtype)
    )

    # Apply the per-leaf shard functions within the device mesh context.
    with self.mesh:
        self.params = tree_apply(shard_fns, self.params)
| 196 | @cached_property |
| 197 | def _forward_generate(self): |
no test coverage detected