MCPcopy Index your code
hub / github.com/LargeWorldModel/LWM / main

Function main

lwm/train.py:59–391  ·  view source on GitHub ↗
(argv)

Source from the content-addressed store, hash-verified

57
58
59def main(argv):
60 JaxDistributedConfig.initialize(FLAGS.jax_distributed)
61 variant = tux.get_user_flags(FLAGS, FLAGS_DEF)
62 flags_config_dict = tux.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
63
64 logger = tux.WandBLogger(
65 config=FLAGS.logger,
66 variant=variant,
67 enable=FLAGS.log_all_worker or (jax.process_index() == 0),
68 )
69 set_random_seed(FLAGS.seed)
70
71 if jax.process_index() == 0:
72 output_dir = logger.output_dir
73 else:
74 output_dir = os.path.join(logger.output_dir, logger.experiment_id)
75
76 if FLAGS.modality == 'text':
77 config_cls = LLaMAConfig
78 llama_cls = FlaxLLaMAForCausalLMModule
79 elif FLAGS.modality == 'vision,text':
80 config_cls = VideoLLaMAConfig
81 llama_cls = FlaxVideoLLaMAForCausalLMModule
82 else:
83 raise ValueError(f"Unsupported modality: {FLAGS.modality}")
84
85 mesh = config_cls.get_jax_mesh(FLAGS.mesh_dim)
86 node_info = config_cls.get_ranks_and_size(mesh)
87
88 tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer)
89 dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer, node_info=node_info)
90 if FLAGS.autoresume and tux.check_exists(output_dir):
91 logging.info('Found existing output. Resuming dataset from latest checkpoint...')
92 resume_path = f"{output_dir}/dataset.pkl"
93 dataset.load_state_dict(tux.load_pickle(resume_path))
94 elif FLAGS.load_dataset_state != '':
95 dataset.load_state_dict(tux.load_pickle(FLAGS.load_dataset_state))
96
97 if FLAGS.eval_steps > 0:
98 eval_dataset = DatasetFactory.load_dataset(
99 FLAGS.eval_dataset, dataset.tokenizer
100 )
101 eval_iterator = iter(eval_dataset)
102
103 seq_length = dataset.seq_length
104
105 if FLAGS.load_llama_config != '':
106 llama_config = config_cls.load_config(FLAGS.load_llama_config)
107 updates = config_cls(**FLAGS.llama)
108 llama_config.update(dict(
109 scan_attention=updates.scan_attention,
110 scan_mlp=updates.scan_mlp,
111 scan_query_chunk_size=updates.scan_query_chunk_size,
112 scan_key_chunk_size=updates.scan_key_chunk_size,
113 scan_mlp_chunk_size=updates.scan_mlp_chunk_size,
114 scan_layers=updates.scan_layers,
115 param_scan_axis=updates.param_scan_axis,
116 ))

Callers

nothing calls this directly

Calls 8

save_checkpoint — Function · 0.85
get_jax_mesh — Method · 0.80
get_ranks_and_size — Method · 0.80
load_dataset — Method · 0.80
load_state_dict — Method · 0.45
load_config — Method · 0.45
get_partition_rules — Method · 0.45

Tested by

no test coverage detected