()
| 407 | |
| 408 | |
def generate_reports():
    """Print the SFT/eval and export reports for every experiment group.

    Walks ``./experiment`` recursively and passes each file whose path ends
    with ``.json`` and does not contain ``ipynb`` to ``parse_output``. The
    parsed outputs are bucketed by their ``group`` attribute, and for each
    group two reports are printed: one built by ``generate_sft_report`` from
    the outputs whose ``cmd`` is ``'sft'`` or ``'eval'``, and one built by
    ``generate_export_report`` from the outputs whose ``cmd`` is ``'export'``.

    Groups are reported in the order they are first encountered, and the
    directory walk is sorted, so repeated runs over the same tree print the
    groups in the same order.
    """
    outputs = []
    for root, dirnames, filenames in os.walk('./experiment'):
        # In top-down mode os.walk descends into ``dirnames`` in list order,
        # so sorting it in place gives a stable traversal order.
        dirnames.sort()
        for filename in sorted(filenames):
            abs_file = os.path.join(root, filename)
            # The 'ipynb' test is applied to the whole path, so it also
            # skips files located under directories whose name contains it.
            if not abs_file.endswith('.json') or 'ipynb' in abs_file:
                continue

            outputs.append(parse_output(abs_file))

    # Bucket in one pass. A dict keeps insertion order, whereas iterating a
    # set of group names could change order from one run to the next.
    outputs_by_group = {}
    for output in outputs:
        outputs_by_group.setdefault(output.group, []).append(output)

    for group, group_outputs in outputs_by_group.items():
        print(f'=================Printing the sft cmd result of exp {group}==================\n\n')
        print(generate_sft_report([output for output in group_outputs if output.cmd in ('sft', 'eval')]))
        # DPO reporting is currently disabled:
        # print(f'=================Printing the dpo result of exp {group}==================')
        # print(generate_dpo_report([output for output in outputs if output.cmd == 'dpo']))
        print(f'=================Printing the export cmd result of exp {group}==================\n\n')
        print(generate_export_report([output for output in group_outputs if output.cmd == 'export']))
        # NOTE(review): assumed to be printed once per group -- confirm
        # against the original indentation.
        print('=================Printing done==================\n\n')
| 429 | |
| 430 | |
| 431 | if __name__ == '__main__': |
no test coverage detected