(model, data, epoch, args, tb_writer=None, extra_suffix="")
| 207 | |
| 208 | |
| 209 | def evaluate(model, data, epoch, args, tb_writer=None, extra_suffix=""): |
| 210 | metrics = {} |
| 211 | if not args.parallel_eval: |
| 212 | if not is_master(args): |
| 213 | return metrics |
| 214 | device = torch.device(args.device) |
| 215 | model.eval() |
| 216 | |
| 217 | # CHANGE |
| 218 | # zero_shot_metrics = zero_shot_eval(model, data, epoch, args) |
| 219 | # metrics.update(zero_shot_metrics) |
| 220 | if is_master(args): |
| 221 | print("Evaluating...") |
| 222 | metric_names = args.lp_metrics.split(",") |
| 223 | eval_tool = LPMetrics(metric_names=metric_names) |
| 224 | |
| 225 | autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress |
| 226 | if "val" in data and ( |
| 227 | args.val_frequency |
| 228 | and ((epoch % args.val_frequency) == 0 or epoch == args.epochs) |
| 229 | ): |
| 230 | if args.parallel_eval: |
| 231 | dataloader, sampler = data["val"].dataloader, data["val"].sampler |
| 232 | if args.distributed and sampler is not None: |
| 233 | sampler.set_epoch(epoch) |
| 234 | samples_per_val = dataloader.num_samples |
| 235 | else: |
| 236 | dataloader = data["val"].dataloader |
| 237 | num_samples = 0 |
| 238 | samples_per_val = dataloader.num_samples |
| 239 | |
| 240 | eval_info = {"pred": [], "target": []} |
| 241 | with torch.no_grad(): |
| 242 | for i, batch in enumerate(dataloader): |
| 243 | audio = batch # contains mel_spec, wavform, and longer list |
| 244 | class_label = batch["class_label"] |
| 245 | |
| 246 | # audio = audio.to(device=device, non_blocking=True) |
| 247 | class_label = class_label.to(device=device, non_blocking=True) |
| 248 | |
| 249 | with autocast(): |
| 250 | pred = model(audio, device=device) |
| 251 | if args.parallel_eval: |
| 252 | pred, class_label = lp_gather_features( |
| 253 | pred, class_label, args.world_size, args.horovod |
| 254 | ) |
| 255 | eval_info["pred"].append(pred) |
| 256 | eval_info["target"].append(class_label) |
| 257 | |
| 258 | num_samples += class_label.shape[0] |
| 259 | |
| 260 | if (i % 100) == 0: # and i != 0: |
| 261 | logging.info( |
| 262 | f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]" |
| 263 | ) |
| 264 | |
| 265 | if is_master(args): |
| 266 | eval_info["pred"] = torch.cat(eval_info["pred"], 0).cpu() |
no test coverage detected