(self)
| 50 | args: args_class |
| 51 | |
| 52 | def run(self): |
| 53 | lang = os.environ.get("SWIFT_UI_LANG") or self.args.lang |
| 54 | share_env = os.environ.get("WEBUI_SHARE") |
| 55 | share = strtobool(share_env) if share_env else self.args.share |
| 56 | server = os.environ.get("WEBUI_SERVER") or self.args.server_name |
| 57 | port_env = os.environ.get("WEBUI_PORT") |
| 58 | port = int(port_env) if port_env else self.args.server_port |
| 59 | LLMTrain.set_lang(lang) |
| 60 | LLMRLHF.set_lang(lang) |
| 61 | LLMGRPO.set_lang(lang) |
| 62 | LLMInfer.set_lang(lang) |
| 63 | LLMExport.set_lang(lang) |
| 64 | LLMEval.set_lang(lang) |
| 65 | LLMSample.set_lang(lang) |
| 66 | with gr.Blocks(title="SWIFT WebUI", theme=gr.themes.Base()) as app: |
| 67 | try: |
| 68 | _version = swift.__version__ |
| 69 | except AttributeError: |
| 70 | _version = "" |
| 71 | gr.HTML( |
| 72 | f"<h1><center>{locale_dict['title'][lang]}({_version})</center></h1>" |
| 73 | ) |
| 74 | gr.HTML(f"<h3><center>{locale_dict['sub_title'][lang]}</center></h3>") |
| 75 | with gr.Tabs(): |
| 76 | LLMTrain.build_ui(LLMTrain) |
| 77 | LLMRLHF.build_ui(LLMRLHF) |
| 78 | LLMGRPO.build_ui(LLMGRPO) |
| 79 | LLMInfer.build_ui(LLMInfer) |
| 80 | LLMExport.build_ui(LLMExport) |
| 81 | LLMEval.build_ui(LLMEval) |
| 82 | LLMSample.build_ui(LLMSample) |
| 83 | |
| 84 | concurrent = {} |
| 85 | if version.parse(gr.__version__) < version.parse("4.0.0"): |
| 86 | concurrent = {"concurrency_count": 5} |
| 87 | app.load( |
| 88 | partial(LLMTrain.update_input_model, arg_cls=RLHFArguments), |
| 89 | inputs=[LLMTrain.element("model")], |
| 90 | outputs=[LLMTrain.element("train_record")] |
| 91 | + list(LLMTrain.valid_elements().values()), |
| 92 | ) |
| 93 | app.load( |
| 94 | partial(LLMRLHF.update_input_model, arg_cls=RLHFArguments), |
| 95 | inputs=[LLMRLHF.element("model")], |
| 96 | outputs=[LLMRLHF.element("train_record")] |
| 97 | + list(LLMRLHF.valid_elements().values()), |
| 98 | ) |
| 99 | app.load( |
| 100 | partial(LLMGRPO.update_input_model, arg_cls=RLHFArguments), |
| 101 | inputs=[LLMGRPO.element("model")], |
| 102 | outputs=[LLMGRPO.element("train_record")] |
| 103 | + list(LLMGRPO.valid_elements().values()), |
| 104 | ) |
| 105 | app.load( |
| 106 | partial( |
| 107 | LLMInfer.update_input_model, |
| 108 | arg_cls=DeployArguments, |
| 109 | has_record=False, |
no test coverage detected