(argv)
| 57 | |
| 58 | |
| 59 | def main(argv): |
| 60 | JaxDistributedConfig.initialize(FLAGS.jax_distributed) |
| 61 | variant = tux.get_user_flags(FLAGS, FLAGS_DEF) |
| 62 | flags_config_dict = tux.user_flags_to_config_dict(FLAGS, FLAGS_DEF) |
| 63 | |
| 64 | logger = tux.WandBLogger( |
| 65 | config=FLAGS.logger, |
| 66 | variant=variant, |
| 67 | enable=FLAGS.log_all_worker or (jax.process_index() == 0), |
| 68 | ) |
| 69 | set_random_seed(FLAGS.seed) |
| 70 | |
| 71 | if jax.process_index() == 0: |
| 72 | output_dir = logger.output_dir |
| 73 | else: |
| 74 | output_dir = os.path.join(logger.output_dir, logger.experiment_id) |
| 75 | |
| 76 | if FLAGS.modality == 'text': |
| 77 | config_cls = LLaMAConfig |
| 78 | llama_cls = FlaxLLaMAForCausalLMModule |
| 79 | elif FLAGS.modality == 'vision,text': |
| 80 | config_cls = VideoLLaMAConfig |
| 81 | llama_cls = FlaxVideoLLaMAForCausalLMModule |
| 82 | else: |
| 83 | raise ValueError(f"Unsupported modality: {FLAGS.modality}") |
| 84 | |
| 85 | mesh = config_cls.get_jax_mesh(FLAGS.mesh_dim) |
| 86 | node_info = config_cls.get_ranks_and_size(mesh) |
| 87 | |
| 88 | tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer) |
| 89 | dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer, node_info=node_info) |
| 90 | if FLAGS.autoresume and tux.check_exists(output_dir): |
| 91 | logging.info('Found existing output. Resuming dataset from latest checkpoint...') |
| 92 | resume_path = f"{output_dir}/dataset.pkl" |
| 93 | dataset.load_state_dict(tux.load_pickle(resume_path)) |
| 94 | elif FLAGS.load_dataset_state != '': |
| 95 | dataset.load_state_dict(tux.load_pickle(FLAGS.load_dataset_state)) |
| 96 | |
| 97 | if FLAGS.eval_steps > 0: |
| 98 | eval_dataset = DatasetFactory.load_dataset( |
| 99 | FLAGS.eval_dataset, dataset.tokenizer |
| 100 | ) |
| 101 | eval_iterator = iter(eval_dataset) |
| 102 | |
| 103 | seq_length = dataset.seq_length |
| 104 | |
| 105 | if FLAGS.load_llama_config != '': |
| 106 | llama_config = config_cls.load_config(FLAGS.load_llama_config) |
| 107 | updates = config_cls(**FLAGS.llama) |
| 108 | llama_config.update(dict( |
| 109 | scan_attention=updates.scan_attention, |
| 110 | scan_mlp=updates.scan_mlp, |
| 111 | scan_query_chunk_size=updates.scan_query_chunk_size, |
| 112 | scan_key_chunk_size=updates.scan_key_chunk_size, |
| 113 | scan_mlp_chunk_size=updates.scan_mlp_chunk_size, |
| 114 | scan_layers=updates.scan_layers, |
| 115 | param_scan_axis=updates.param_scan_axis, |
| 116 | )) |
nothing calls this directly
no test coverage detected