hub / github.com/LargeWorldModel/LWM / main

Function main

lwm/vision_generation.py:44–250
(argv)

Source from the content-addressed store, hash-verified

def main(argv):
    assert FLAGS.output_file != ''
    if FLAGS.output_file.endswith('mp4'):
        assert FLAGS.n_frames > 1
    elif FLAGS.output_file.endswith('png') or FLAGS.output_file.endswith('jpg'):
        assert FLAGS.n_frames == 1
    else:
        raise ValueError(f"Unsupported output file extension: {FLAGS.output_file}")

    JaxDistributedConfig.initialize(FLAGS.jax_distributed)
    set_random_seed(FLAGS.seed)

    tokens_per_frame = 257
    vqgan = VQGAN(FLAGS.vqgan_checkpoint, replicate=False)
    mesh = VideoLLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
    tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer)
    prefix_tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer, truncation_side='left', padding_side='left')
    if FLAGS.load_llama_config != '':
        llama_config = VideoLLaMAConfig.load_config(FLAGS.load_llama_config)
        updates = VideoLLaMAConfig(**FLAGS.llama)
        llama_config.update(dict(
            scan_attention=updates.scan_attention,
            scan_mlp=updates.scan_mlp,
            scan_query_chunk_size=updates.scan_query_chunk_size,
            scan_key_chunk_size=updates.scan_key_chunk_size,
            scan_mlp_chunk_size=updates.scan_mlp_chunk_size,
            scan_layers=updates.scan_layers,
            param_scan_axis=updates.param_scan_axis,
        ))
    else:
        llama_config = VideoLLaMAConfig(**FLAGS.llama)

    if FLAGS.update_llama_config != '':
        llama_config.update(dict(eval(FLAGS.update_llama_config)))

    llama_config.update(dict(
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    ))
    llama_config.update(dict(mesh_dim=FLAGS.mesh_dim))

    with jax.default_device(jax.devices("cpu")[0]):
        _, params = StreamingCheckpointer.load_trainstate_checkpoint(
            FLAGS.load_checkpoint, disallow_trainstate=True, max_buffer_size=32 * 2 ** 30
        )
        model = FlaxVideoLLaMAForCausalLM(
            llama_config,
            input_shape=(512, 8192),
            seed=FLAGS.seed,
            _do_init=False,
            dtype=get_float_dtype_by_name(FLAGS.dtype),
        )
        model_ps = match_partition_rules(
            VideoLLaMAConfig.get_partition_rules(llama_config.scan_layers, llama_config.param_scan_axis), params
        )
        shard_fns, _ = make_shard_and_gather_fns(
            model_ps, get_float_dtype_by_name(FLAGS.dtype)
        )
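
The opening checks in main tie the output extension to the frame count: mp4 needs more than one frame, while png and jpg need exactly one. A minimal standalone sketch of that rule follows. The helper name check_output_spec is hypothetical and not part of the repository. It also deviates from the listing in three labeled ways: it raises ValueError instead of using assert (asserts are skipped under python -O), it matches the real file suffix rather than using endswith, and it ignores case.

```python
from pathlib import Path


def check_output_spec(output_file: str, n_frames: int) -> str:
    """Return 'video' or 'image', or raise ValueError for an inconsistent request.

    Hypothetical helper: mirrors the extension / n_frames rule in main(),
    but with explicit exceptions and suffix-based, case-insensitive matching.
    """
    if not output_file:
        raise ValueError("output_file must not be empty")

    suffix = Path(output_file).suffix.lower()
    if suffix == ".mp4":
        # Video output only makes sense with more than one frame.
        if n_frames <= 1:
            raise ValueError("mp4 output needs n_frames > 1")
        return "video"
    if suffix in {".png", ".jpg"}:
        # Still-image output is a single frame.
        if n_frames != 1:
            raise ValueError("png/jpg output needs n_frames == 1")
        return "image"
    raise ValueError(f"Unsupported output file extension: {output_file}")
```

Calling check_output_spec("out.mp4", 8) returns "video", while check_output_spec("out.mp4", 1) raises ValueError before any model or checkpoint is loaded.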

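The configuration logic in main layers three sources: a saved config, the scan-related fields taken from the command-line config, and an optional update string that is passed through eval. The sketch below illustrates that layering with plain dictionaries, under stated assumptions: merge_config and SCAN_KEYS are hypothetical names, dictionaries stand in for VideoLLaMAConfig, and ast.literal_eval is swapped in for eval as a more restrictive alternative. With literal_eval the update string must be a Python literal such as "{'key': value}"; a string like "dict(key=value)", which eval accepts, would be rejected.

```python
import ast

# The scan-related fields that main() copies from the command-line config.
SCAN_KEYS = (
    "scan_attention",
    "scan_mlp",
    "scan_query_chunk_size",
    "scan_key_chunk_size",
    "scan_mlp_chunk_size",
    "scan_layers",
    "param_scan_axis",
)


def merge_config(saved: dict, cli: dict, update_literal: str = "") -> dict:
    """Hypothetical helper: layer CLI scan settings and a literal update onto a saved config."""
    merged = dict(saved)  # copy so the saved config is left untouched
    # Only scan-related keys are taken from the CLI config; everything else
    # (model size, vocabulary, ...) keeps its saved value.
    merged.update({key: cli[key] for key in SCAN_KEYS if key in cli})
    if update_literal:
        # literal_eval parses literals only, so arbitrary code cannot run here.
        merged.update(dict(ast.literal_eval(update_literal)))
    return merged
```

For example, merging a saved config {"hidden_size": 4096, "scan_layers": True} with CLI values {"scan_attention": True, "hidden_size": 1} keeps hidden_size at 4096 and sets scan_attention to True. The key names in this usage note are illustrative only.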
Callers

nothing calls this directly

Calls 8

VQGAN · Class · 0.90
VideoLLaMAConfig · Class · 0.90
generate_first_frame · Function · 0.85
generate_video_pred · Function · 0.85
get_jax_mesh · Method · 0.80
load_config · Method · 0.45
get_partition_rules · Method · 0.45

Tested by

no test coverage detected