setup_model(args): Set up the model and optimizer.
| 47 | |
| 48 | |
def setup_model(args):
    """Build the model and optionally load its weights from a checkpoint.

    Despite the historical "model and optimizer" wording, no optimizer is
    created here; only the model is returned.

    Args:
        args: Run configuration. This function reads ``args.load``
            (checkpoint directory, or ``None`` to skip loading) and
            ``args.deepspeed`` (selects which loading path is used). The
            whole object is also passed to ``get_model``,
            ``get_checkpoint_iteration`` and ``load_checkpoint``.

    Returns:
        The model built by ``get_model(args)``. When ``args.load`` is set,
        checkpoint weights have been loaded into it.
    """
    model = get_model(args)

    if args.load is None:
        return model

    if args.deepspeed:
        # Only the iteration number is used to build the checkpoint path.
        # NOTE(review): the `release` and `success` results are ignored, as
        # in the original; confirm that a "release" checkpoint or a failed
        # lookup cannot reach this path (torch.load would then fail on a
        # missing file).
        iteration, _release, _success = get_checkpoint_iteration(args)
        path = os.path.join(args.load, str(iteration), "mp_rank_00_model_states.pt")
        # Diagnostic only. Querying the current CUDA device raises when CUDA
        # is unavailable, which would abort a load that is otherwise done on
        # CPU (see map_location below), so report it only when CUDA exists.
        if torch.cuda.is_available():
            print('current device:', torch.cuda.current_device())
        # NOTE(review): torch.load unpickles the file; only point args.load
        # at trusted checkpoints.
        checkpoint = torch.load(path, map_location=torch.device('cpu'))
        # The model weights live under the "module" key of the saved dict.
        model.load_state_dict(checkpoint["module"])
        print(f"Load model file {path}")
    else:
        # Generic loader. The two None arguments fill the slots after the
        # model (presumably optimizer and LR scheduler; confirm), and
        # optimizer state is explicitly not loaded.
        _ = load_checkpoint(
            model, None, None, args, load_optimizer_states=False)

    return model
| 67 | |
| 68 | def _parse_and_to_tensor(text, img_size=256, query_template='{}'): |
| 69 | tokenizer = get_tokenizer() |
No test coverage detected.