()
| 139 | |
| 140 | |
| 141 | def main(): |
| 142 | torch.manual_seed(12345) |
| 143 | parser = argparse.ArgumentParser() |
| 144 | parser.add_argument('--tp-size', type=int, default=1) |
| 145 | parser.add_argument('--out-dir', type=Path, required=True) |
| 146 | parser.add_argument('--num-loras', type=int, default=1) |
| 147 | parser.add_argument('--num-layers', type=int, default=2) |
| 148 | parser.add_argument('--adapter-size', type=int, default=8) |
| 149 | parser.add_argument('--hidden-size', type=int, default=16) |
| 150 | parser.add_argument('--mlp-hidden-size', type=int, default=32) |
| 151 | parser.add_argument('--no-generate-cache-pages', |
| 152 | action='store_true', |
| 153 | default=False) |
| 154 | parser.add_argument( |
| 155 | '--config-ids-filter', |
| 156 | type=str, |
| 157 | default=None, |
| 158 | help= |
| 159 | "Comma separated list of ids to include. For example, use --config-ids-filter=0 for attn_qkv only." |
| 160 | ) |
| 161 | parser.add_argument('--target-file-name', type=str, default="target.npy") |
| 162 | parser.add_argument('--config-file-name', type=str, default="config.npy") |
| 163 | |
| 164 | args = parser.parse_args() |
| 165 | |
| 166 | num_layers = args.num_layers |
| 167 | adapter_size = args.adapter_size |
| 168 | hidden_size = args.hidden_size |
| 169 | mlp_hidden_size = args.mlp_hidden_size |
| 170 | configs = [ |
| 171 | (0, num_layers, adapter_size, hidden_size, 3 * hidden_size), # attn_qkv |
| 172 | (1, num_layers, adapter_size // 2, hidden_size, hidden_size), # attn_q |
| 173 | (2, num_layers, adapter_size // 2, hidden_size, hidden_size), # attn_k |
| 174 | (3, num_layers, adapter_size // 2, hidden_size, hidden_size), # attn_v |
| 175 | (4, num_layers, adapter_size, hidden_size, hidden_size), # attn_dense |
| 176 | (5, num_layers, adapter_size, hidden_size, |
| 177 | mlp_hidden_size), # mlp_h_to_4h |
| 178 | (6, num_layers, adapter_size, mlp_hidden_size, |
| 179 | hidden_size), # mlp_4h_to_h |
| 180 | (7, num_layers, adapter_size, hidden_size, mlp_hidden_size), # mlp_gate |
| 181 | (8, num_layers, adapter_size, hidden_size, |
| 182 | 3 * hidden_size), # cross_attn_qkv |
| 183 | (9, num_layers, adapter_size // 2, hidden_size, |
| 184 | hidden_size), # cross_attn_q |
| 185 | (10, num_layers, adapter_size // 2, hidden_size, |
| 186 | hidden_size), # cross_attn_k |
| 187 | (11, num_layers, adapter_size // 2, hidden_size, |
| 188 | hidden_size), # cross_attn_v |
| 189 | (12, num_layers, adapter_size, hidden_size, |
| 190 | hidden_size), # cross_attn_dense |
| 191 | ] |
| 192 | if args.config_ids_filter: |
| 193 | config_ids_filter = [int(x) for x in args.config_ids_filter.split(",")] |
| 194 | configs = [c for c in configs if c[0] in config_ids_filter] |
| 195 | |
| 196 | for lora_idx in range(args.num_loras): |
| 197 | for is_dora in [None, False, True]: |
| 198 | all_source = [] |
no test coverage detected