()
| 120 | |
| 121 | |
| 122 | def main(): |
| 123 | # parse options |
| 124 | cfg = parse_args(phase="demo") # parse config file |
| 125 | cfg.FOLDER = cfg.TEST.FOLDER |
| 126 | |
| 127 | # create logger |
| 128 | logger = create_logger(cfg, phase="test") |
| 129 | |
| 130 | task = cfg.DEMO.TASK |
| 131 | text = None |
| 132 | |
| 133 | output_dir = Path( |
| 134 | os.path.join(cfg.FOLDER, str(cfg.model.model_type), str(cfg.NAME), |
| 135 | "samples_" + cfg.TIME)) |
| 136 | output_dir.mkdir(parents=True, exist_ok=True) |
| 137 | |
| 138 | logger.info(OmegaConf.to_yaml(cfg)) |
| 139 | |
| 140 | # set seed |
| 141 | pl.seed_everything(cfg.SEED_VALUE) |
| 142 | |
| 143 | # gpu setting |
| 144 | if cfg.ACCELERATOR == "gpu": |
| 145 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( |
| 146 | str(x) for x in cfg.DEVICE) |
| 147 | device = torch.device("cuda") |
| 148 | |
| 149 | # Dataset |
| 150 | datamodule = build_data(cfg) |
| 151 | logger.info("datasets module {} initialized".format("".join( |
| 152 | cfg.DATASET.target.split('.')[-2]))) |
| 153 | |
| 154 | # create model |
| 155 | total_time = time.time() |
| 156 | model = build_model(cfg, datamodule) |
| 157 | logger.info("model {} loaded".format(cfg.model.target)) |
| 158 | |
| 159 | # loading state dict |
| 160 | if cfg.TEST.CHECKPOINTS: |
| 161 | logger.info("Loading checkpoints from {}".format(cfg.TEST.CHECKPOINTS)) |
| 162 | state_dict = torch.load(cfg.TEST.CHECKPOINTS, |
| 163 | map_location="cpu")["state_dict"] |
| 164 | model.load_state_dict(state_dict) |
| 165 | else: |
| 166 | logger.warning( |
| 167 | "No checkpoints provided, using random initialized model") |
| 168 | |
| 169 | model.to(device) |
| 170 | |
| 171 | if cfg.DEMO.EXAMPLE: |
| 172 | # Check txt file input |
| 173 | # load txt |
| 174 | return_dict = load_example_input(cfg.DEMO.EXAMPLE, task, model) |
| 175 | text, in_joints = return_dict['text'], return_dict['motion_joints'] |
| 176 | |
| 177 | batch_size = 64 |
| 178 | if text: |
| 179 | for b in tqdm(range(len(text) // batch_size + 1)): |
no test coverage detected