(argv)
| 42 | |
| 43 | |
| 44 | def main(argv): |
| 45 | assert FLAGS.output_file != '' |
| 46 | if FLAGS.output_file.endswith('mp4'): |
| 47 | assert FLAGS.n_frames > 1 |
| 48 | elif FLAGS.output_file.endswith('png') or FLAGS.output_file.endswith('jpg'): |
| 49 | assert FLAGS.n_frames == 1 |
| 50 | else: |
| 51 | raise ValueError(f"Unsupported output file extension: {FLAGS.output_file}") |
| 52 | |
| 53 | JaxDistributedConfig.initialize(FLAGS.jax_distributed) |
| 54 | set_random_seed(FLAGS.seed) |
| 55 | |
| 56 | tokens_per_frame = 257 |
| 57 | vqgan = VQGAN(FLAGS.vqgan_checkpoint, replicate=False) |
| 58 | mesh = VideoLLaMAConfig.get_jax_mesh(FLAGS.mesh_dim) |
| 59 | tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer) |
| 60 | prefix_tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer, truncation_side='left', padding_side='left') |
| 61 | if FLAGS.load_llama_config != '': |
| 62 | llama_config = VideoLLaMAConfig.load_config(FLAGS.load_llama_config) |
| 63 | updates = VideoLLaMAConfig(**FLAGS.llama) |
| 64 | llama_config.update(dict( |
| 65 | scan_attention=updates.scan_attention, |
| 66 | scan_mlp=updates.scan_mlp, |
| 67 | scan_query_chunk_size=updates.scan_query_chunk_size, |
| 68 | scan_key_chunk_size=updates.scan_key_chunk_size, |
| 69 | scan_mlp_chunk_size=updates.scan_mlp_chunk_size, |
| 70 | scan_layers=updates.scan_layers, |
| 71 | param_scan_axis=updates.param_scan_axis, |
| 72 | )) |
| 73 | else: |
| 74 | llama_config = VideoLLaMAConfig(**FLAGS.llama) |
| 75 | |
| 76 | if FLAGS.update_llama_config != '': |
| 77 | llama_config.update(dict(eval(FLAGS.update_llama_config))) |
| 78 | |
| 79 | llama_config.update(dict( |
| 80 | bos_token_id=tokenizer.bos_token_id, |
| 81 | eos_token_id=tokenizer.eos_token_id, |
| 82 | )) |
| 83 | llama_config.update(dict(mesh_dim=FLAGS.mesh_dim)) |
| 84 | |
| 85 | with jax.default_device(jax.devices("cpu")[0]): |
| 86 | _, params = StreamingCheckpointer.load_trainstate_checkpoint( |
| 87 | FLAGS.load_checkpoint, disallow_trainstate=True, max_buffer_size=32 * 2 ** 30 |
| 88 | ) |
| 89 | model = FlaxVideoLLaMAForCausalLM( |
| 90 | llama_config, |
| 91 | input_shape=(512, 8192), |
| 92 | seed=FLAGS.seed, |
| 93 | _do_init=False, |
| 94 | dtype=get_float_dtype_by_name(FLAGS.dtype), |
| 95 | ) |
| 96 | model_ps = match_partition_rules( |
| 97 | VideoLLaMAConfig.get_partition_rules(llama_config.scan_layers, llama_config.param_scan_axis), params |
| 98 | ) |
| 99 | shard_fns, _ = make_shard_and_gather_fns( |
| 100 | model_ps, get_float_dtype_by_name(FLAGS.dtype) |
| 101 | ) |
nothing calls this directly
no test coverage detected