()
def test_model_arch(start: int = 0):
    """Cross-check each registered model type's declared architectures against hub configs.

    For every ``(model_type, model_meta)`` pair in ``MODEL_MAPPING``, starting at
    index ``start``, one model id is sampled at random from each model group. Only
    its config is fetched (``download_model=False``). The first entry of the
    config's ``architectures`` field is then compared with ``model_meta.architectures``.

    One JSON line is appended to ``model_arch.jsonl`` for each of these cases:

    * the config's first architecture is not in ``model_meta.architectures``;
    * the config declares no architectures but ``model_meta`` does;
    * fetching or parsing the config failed. The record has ``'msg': 'error'``
      and the exception ``repr`` under ``'error'``.

    Sampling is random, so which models get checked varies between runs.

    Args:
        start: Index into ``MODEL_MAPPING`` from which to begin. Earlier entries
            are skipped, which is useful for resuming an interrupted run. The
            default 0 checks every entry.
    """
    import random
    from transformers import PretrainedConfig

    from swift.model import MODEL_MAPPING
    from swift.utils import JsonlWriter, safe_snapshot_download

    jsonl_writer = JsonlWriter('model_arch.jsonl')
    for i, (model_type, model_meta) in enumerate(MODEL_MAPPING.items()):
        if i < start:
            continue
        arch_list = model_meta.architectures
        for model_group in model_meta.model_groups:
            model = random.choice(model_group.models).ms_model_id
            try:
                model_dir = safe_snapshot_download(model, download_model=False)
                config_dict = PretrainedConfig.get_config_dict(model_dir)[0]
            except Exception as exc:
                # Best effort: record why this model could not be checked and keep going,
                # instead of silently discarding the cause.
                jsonl_writer.append({
                    'msg': 'error',
                    'model_type': model_type,
                    'model': model,
                    'arch_list': arch_list,
                    'error': repr(exc),
                })
                continue

            arch = config_dict.get('architectures')
            if arch:
                # Only the primary (first) architecture in the config is checked.
                # `or ()` guards a None `architectures`; the raw value is still reported.
                mismatch = arch[0] not in (arch_list or ())
            else:
                # The config declares no architecture, but the registry expects some.
                mismatch = bool(arch_list)
            if mismatch:
                jsonl_writer.append({
                    'model_type': model_type,
                    'model': model,
                    'config_arch': arch,
                    'architectures': arch_list
                })
| 42 | |
| 43 | |
| 44 | if __name__ == '__main__': |
no test coverage detected